From b2fd06c6aa00490d16ef206a76c04dfc9149e60b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 May 2023 23:06:49 +0300 Subject: [PATCH] mtl : working mul_mat q4 --- examples/mtl/mtl.m | 7 ++++--- examples/mtl/mtl.metal | 8 +++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index 8985a0d74..e447dfcf6 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -389,11 +389,12 @@ int llama_mtl_eval( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1); @@ -446,7 +447,7 @@ int llama_mtl_eval( [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; const int64_t nrows = ggml_nrows(gf->nodes[i]->src0); diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index 7eb9259e2..0cd93df7f 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -127,6 +127,7 @@ kernel void kernel_mul_mat_q4_0( constant int64_t & ne11, constant int64_t & ne0, constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], @@ -140,10 +141,9 @@ kernel void kernel_mul_mat_q4_0( device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; device const float * y = (device const float *) src1 + r1*ne10; - threadgroup float sum[32]; // TODO: should be equal to threadgroup size sum[tpitg.x] = 0.0f; - for (int i = 0; i < nb; i += tptg.x) { + for (int i = tpitg.x; i < nb; i += tptg.x) { //device const uint4 * x0p = (device const uint4 *) (x + i)->qs; //device const float4 * y0p = (device const float4 *) (y + i*qk); @@ -206,5 +206,7 @@ kernel void kernel_mul_mat_q4_0( threadgroup_barrier(mem_flags::mem_threadgroup); } - dst[r1*ne0 + r0] = sum[0]; + if (tpitg.x == 0) { + dst[r1*ne0 + r0] = sum[0]; + } }