diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index 372396047..ebd3d6235 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -480,41 +480,41 @@ int llama_mtl_eval( const int64_t ne01 = gf->nodes[i]->src0->ne[1]; const int64_t ne02 = gf->nodes[i]->src0->ne[2]; - //const uint64_t nb00 = gf->nodes[i]->src0->nb[0]; - //const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; + const uint64_t nb00 = gf->nodes[i]->src0->nb[0]; + const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; const uint64_t nb02 = gf->nodes[i]->src0->nb[2]; const int64_t ne10 = gf->nodes[i]->src1->ne[0]; const int64_t ne11 = gf->nodes[i]->src1->ne[1]; const int64_t ne12 = gf->nodes[i]->src1->ne[2]; - //const uint64_t nb10 = gf->nodes[i]->src1->nb[0]; - //const uint64_t nb11 = gf->nodes[i]->src1->nb[1]; + const uint64_t nb10 = gf->nodes[i]->src1->nb[0]; + const uint64_t nb11 = gf->nodes[i]->src1->nb[1]; const uint64_t nb12 = gf->nodes[i]->src1->nb[2]; const int64_t ne0 = gf->nodes[i]->ne[0]; const int64_t ne1 = gf->nodes[i]->ne[1]; const int64_t ne2 = gf->nodes[i]->ne[2]; - //const uint64_t nb0 = gf->nodes[i]->nb[0]; - //const uint64_t nb1 = gf->nodes[i]->nb[1]; + const uint64_t nb0 = gf->nodes[i]->nb[0]; + const uint64_t nb1 = gf->nodes[i]->nb[1]; const uint64_t nb2 = gf->nodes[i]->nb[2]; - const int nth = 16; - const enum ggml_type src0t = gf->nodes[i]->src0->type; const enum ggml_type src1t = gf->nodes[i]->src1->type; const enum ggml_type dstt = gf->nodes[i]->type; - fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02); - fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12); + fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0)); + fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1)); fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2); fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt)); GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne02 == ne12); - if ((src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { + if (ggml_is_contiguous(gf->nodes[i]->src0) && + ggml_is_contiguous(gf->nodes[i]->src1) && + (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { if (encoder != nil) { [encoder endEncoding]; encoder = nil; @@ -555,25 +555,52 @@ int llama_mtl_eval( encoder = [command_buffer computeCommandEncoder]; } + int nth = 32; + // use custom matrix x vector kernel switch (src0t) { - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break; + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth = 4; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(ne02 == ne12); + + nth = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } break; default: GGML_ASSERT(false && "not implemented"); }; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + if (src0t == GGML_TYPE_Q4_0) { + [encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)]; + } else { + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } } break; case GGML_OP_GET_ROWS: diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index 1bada42dd..2272f9ff3 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -265,7 +265,10 @@ kernel void kernel_mul_mat_q4_0_f32( device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; device const float * y = (device const float *) src1 + r1*ne10; - sum[tpitg.x] = 0.0f; + const uint nth = tptg.x*tptg.y; + const uint ith = 16*tpitg.x + tpitg.y; + + sum[ith] = 0.0f; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uchar * x0p = (device const uchar *) (x + i)->qs; @@ -273,7 +276,9 @@ kernel void kernel_mul_mat_q4_0_f32( float acc = 0.0f; - for (int j = 0; j < 16; ++j) { + //for (int j = 0; j < 16; ++j) { + const int j = tpitg.y; + { const uchar x0v = *(x0p + j); const int x0 = x0v & 0x0F; @@ -285,43 +290,50 @@ kernel void kernel_mul_mat_q4_0_f32( acc += (x0 - 8)*y0 + (x1 - 8)*y1; } - sum[tpitg.x] += acc * (x + i)->d; + sum[ith] += acc * (x + i)->d; } // accumulate the sum from all threads in the threadgroup threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = tptg.x/2; i > 0; i /= 2) { - if (tpitg.x < i) { - sum[tpitg.x] += sum[tpitg.x + i]; + for (uint i = nth/2; i > 0; i /= 2) { + if (ith < i) { + sum[ith] += sum[ith + i]; } threadgroup_barrier(mem_flags::mem_threadgroup); } - if (tpitg.x == 0) { + if (ith == 0) { dst[r1*ne0 + r0] = sum[0]; } } kernel void kernel_mul_mat_f16_f32( - device const half * src0, - device const float * src1, + device const char * src0, + device const char * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpig[[thread_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; - device const half * x = src0 + r0*ne00; - device const float * y = src1 + r1*ne10; + device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); sum[tpitg.x] = 0.0f; @@ -339,7 +351,7 @@ kernel void kernel_mul_mat_f16_f32( } if (tpitg.x == 0) { - dst[r1*ne0 + r0] = sum[0]; + dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; } } diff --git a/ggml.c b/ggml.c index 1c9bb4e61..b5e6997dd 100644 --- a/ggml.c +++ b/ggml.c @@ -3821,11 +3821,11 @@ size_t ggml_tensor_overhead(void) { return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16; } -static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) { +bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } -static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return diff --git a/ggml.h b/ggml.h index 1f033b492..7f821cf32 100644 --- a/ggml.h +++ b/ggml.h @@ -442,6 +442,9 @@ extern "C" { // TODO: temporary until model loading of ggml examples is refactored GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void);