diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index 4ef1efae4..c74c28cd9 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -598,6 +598,7 @@ int llama_mtl_eval( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; if (src0t == GGML_TYPE_Q4_0) { + //printf("nb = %d\n", ne00/32); [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index 5bbbfecd6..53f7f7448 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -247,8 +247,14 @@ kernel void kernel_mul_mat_q4_0_f32( device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], @@ -256,12 +262,11 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { + const int nb = ne00/QK4_0; + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int qk = QK4_0; - const int nb = ne00/qk; - device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; device const float * y = (device const float *) src1 + r1*ne10; @@ -272,7 +277,7 @@ kernel void kernel_mul_mat_q4_0_f32( for (int i = tpitg.x; i < nb; i += tptg.x) { device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs; - device const float4 * y0p = (device const float4 *) (y + i*qk); + device const float4 * y0p = (device const float4 *) (y + i*QK4_0); const float d = (float)((x + i)->d); @@ -282,42 +287,12 @@ kernel void kernel_mul_mat_q4_0_f32( float acc = 0.0f; - { - const int x0 = x0v[0] & 0x0F; - const int x1 = x0v[0] >> 4; + for (int j = 0; j < 4; ++j) { + const int x0 = x0v[j] & 0x0F; + const int x1 = x0v[j] >> 4; - const float y0 = y0v[0]; - const float y1 = y1v[0]; - - acc += (x0 - 8)*y0 + (x1 - 8)*y1; - } - - { - const int x0 = x0v[1] & 0x0F; - const int x1 = x0v[1] >> 4; - - const float y0 = y0v[1]; - const float y1 = y1v[1]; - - acc += (x0 - 8)*y0 + (x1 - 8)*y1; - } - - { - const int x0 = x0v[2] & 0x0F; - const int x1 = x0v[2] >> 4; - - const float y0 = y0v[2]; - const float y1 = y1v[2]; - - acc += (x0 - 8)*y0 + (x1 - 8)*y1; - } - - { - const int x0 = x0v[3] & 0x0F; - const int x1 = x0v[3] >> 4; - - const float y0 = y0v[3]; - const float y1 = y1v[3]; + const float y0 = y0v[j]; + const float y1 = y1v[j]; acc += (x0 - 8)*y0 + (x1 - 8)*y1; }