diff --git a/ggml-metal.m b/ggml-metal.m index a08abbc29..336f8e740 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute( { //GGML_ASSERT(ne00 == ne10); //GGML_ASSERT(ne03 == ne13); - - GGML_ASSERT(src0t == GGML_TYPE_I32); - - const int n_as = ((int32_t *) dst->op_params)[1]; - - // TODO: make this more general - GGML_ASSERT(n_as <= 8); + const int n_as = src0->ne[2]; // max size of the src1ids array in the kernel shared buffer GGML_ASSERT(ne11 <= 4096); - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); + // src2 = ids + const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20); + const int64_t ne21 = src2->ne[1]; + const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); + const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); + const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); + const uint64_t nb21 = src2->nb[1]; + const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); + const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - GGML_ASSERT(!ggml_is_transposed(src2)); + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(src1t == GGML_TYPE_F32); - const uint r2 = ne12/ne22; - const uint r3 = ne13/ne23; - // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel int ne11_mm_min = n_as; @@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute( const int idx = ((int32_t *) dst->op_params)[0]; // batch size - GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(ne21 == ne11); // ? + GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting + const uint r2 = 1; + const uint r3 = 1; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1732,20 +1729,20 @@ static enum ggml_status ggml_metal_graph_compute( // indirect matrix multiplication // !!! if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne20 % 32 == 0 && ne20 >= 64 && + ne00 % 32 == 0 && ne00 >= 64 && ne11 > ne11_mm_min) { // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + switch (src2->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb21 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb21 % 8 == 0); break; default: break; } id pipeline = nil; - switch (src2->type) { + switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; @@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; - // TODO: how to make this an array? read Metal docs - for (int j = 0; j < 8; ++j) { - // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 - struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; - - size_t offs_src_cur = 0; - id id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; - } + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:19]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; // use custom matrix x vector kernel - switch (src2t) { + switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); @@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute( } }; - if (ggml_is_quantized(src2t)) { - GGML_ASSERT(ne20 >= nth0*nth1); + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); } const int64_t _ne1 = 1; // kernels needs a reference in constant memory @@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; - // TODO: how to make this an array? read Metal docs - for (int j = 0; j < 8; ++j) { - // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 - struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:21]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:22]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:23]; - size_t offs_src_cur = 0; - id id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - - if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || - src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K || - src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) { - const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) { - const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) { + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q3_K) { + else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } - else if (src2t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { const int64_t ny = (_ne1 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 744b2a8b4..a876af365 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5785,9 +5785,10 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -5804,22 +5805,14 @@ kernel void kernel_mul_mm_id( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // expert id const int32_t id = tgpig.z/(ne12*ne13); + device const uchar * src0 = src0s + id*nb02; tgpig.z = tgpig.z%(ne12*ne13); @@ -5834,7 +5827,7 @@ kernel void kernel_mul_mm_id( } kernel_mul_mm_id_impl( - src0s[id], + src0, src1, src1ids, dst, @@ -5960,9 +5953,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m // typedef void (mat_mm_id_t)( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -5979,14 +5973,6 @@ typedef void (mat_mm_id_t)( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar *, uint3, uint, uint); @@ -6022,9 +6008,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel [[host_name("kernel_mul_mv_id_f32_f32")]] kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6045,28 +6032,19 @@ kernel void kernel_mul_mv_id_f32_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f32_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6091,9 +6069,10 @@ kernel void kernel_mul_mv_id_f32_f32( [[host_name("kernel_mul_mv_id_f16_f32")]] kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6114,28 +6093,19 @@ kernel void kernel_mul_mv_id_f16_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f16_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6160,9 +6130,10 @@ kernel void kernel_mul_mv_id_f16_f32( [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6183,28 +6154,19 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q8_0_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6223,9 +6185,10 @@ kernel void kernel_mul_mv_id_q8_0_f32( [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6246,28 +6209,19 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6286,9 +6240,10 @@ kernel void kernel_mul_mv_id_q4_0_f32( [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6309,28 +6264,19 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6349,9 +6295,10 @@ kernel void kernel_mul_mv_id_q4_1_f32( [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6372,28 +6319,19 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6412,9 +6350,10 @@ kernel void kernel_mul_mv_id_q5_0_f32( [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6435,28 +6374,19 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6475,9 +6405,10 @@ kernel void kernel_mul_mv_id_q5_1_f32( [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6498,28 +6429,19 @@ kernel void kernel_mul_mv_id_q2_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q2_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6538,9 +6460,10 @@ kernel void kernel_mul_mv_id_q2_K_f32( [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6561,28 +6484,19 @@ kernel void kernel_mul_mv_id_q3_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q3_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6601,9 +6515,10 @@ kernel void kernel_mul_mv_id_q3_K_f32( [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6624,28 +6539,19 @@ kernel void kernel_mul_mv_id_q4_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q4_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6664,9 +6570,10 @@ kernel void kernel_mul_mv_id_q4_K_f32( [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6687,28 +6594,19 @@ kernel void kernel_mul_mv_id_q5_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q5_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6727,9 +6625,10 @@ kernel void kernel_mul_mv_id_q5_K_f32( [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6750,28 +6649,19 @@ kernel void kernel_mul_mv_id_q6_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q6_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6790,9 +6680,10 @@ kernel void kernel_mul_mv_id_q6_K_f32( [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel void kernel_mul_mv_id_iq2_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6813,29 +6704,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6855,9 +6737,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel void kernel_mul_mv_id_iq2_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6878,29 +6761,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6920,9 +6794,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel void kernel_mul_mv_id_iq3_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6943,29 +6818,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6985,9 +6851,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel void kernel_mul_mv_id_iq3_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7008,29 +6875,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7050,9 +6908,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32( [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel void kernel_mul_mv_id_iq2_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7073,29 +6932,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7115,9 +6965,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32( [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel void kernel_mul_mv_id_iq1_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7138,28 +6989,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq1_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7178,9 +7020,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32( [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel void kernel_mul_mv_id_iq1_m_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7201,28 +7044,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq1_m_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7241,9 +7075,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32( [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel void kernel_mul_mv_id_iq4_nl_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7264,29 +7099,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq4_nl_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7306,9 +7132,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel void kernel_mul_mv_id_iq4_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7329,33 +7156,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; #if QK_K == 64 kernel_mul_mv_iq4_nl_f32_impl( #else kernel_mul_mv_iq4_xs_f32_impl( #endif - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, diff --git a/ggml.c b/ggml.c index eb2ea6af0..3e87bedc9 100644 --- a/ggml.c +++ b/ggml.c @@ -11049,8 +11049,7 @@ static void ggml_compute_forward_mul_mat_id( continue; } - //const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; - size_t src0_offset = src0->nb[2]*cur_a; + size_t src0_offset = cur_a*src0->nb[2]; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10);