diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index ebd3d6235..4ef1efae4 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -555,7 +555,8 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    int nth = 32;
+                    int nth0 = 32;
+                    int nth1 = 1;
 
                     // use custom matrix x vector kernel
                     switch (src0t) {
@@ -564,14 +565,16 @@ int llama_mtl_eval(
                                 GGML_ASSERT(ne02 == 1);
                                 GGML_ASSERT(ne12 == 1);
 
-                                nth = 4;
+                                nth0 = 8;
+                                nth1 = 4;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                             } break;
                         case GGML_TYPE_F16:
                             {
                                 GGML_ASSERT(ne02 == ne12);
 
-                                nth = 32;
+                                nth0 = 32;
+                                nth1 = 1;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                             } break;
                         default: GGML_ASSERT(false && "not implemented");
@@ -595,11 +598,11 @@ int llama_mtl_eval(
                     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
                     if (src0t == GGML_TYPE_Q4_0) {
-                        [encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)];
+                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                 } break;
 
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 2272f9ff3..5bbbfecd6 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -266,31 +266,63 @@ kernel void kernel_mul_mat_q4_0_f32(
     device const float * y = (device const float *) src1 + r1*ne10;
 
     const uint nth = tptg.x*tptg.y;
-    const uint ith = 16*tpitg.x + tpitg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
 
     sum[ith] = 0.0f;
 
     for (int i = tpitg.x; i < nb; i += tptg.x) {
-        device const uchar * x0p = (device const uchar *) (x + i)->qs;
-        device const float * y0p = (device const float *) (y + i*qk);
+        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
+
+        const float d = (float)((x + i)->d);
+
+        const uchar4 x0v = *(x0p + tpitg.y);
+        const float4 y0v = *(y0p + tpitg.y + 0);
+        const float4 y1v = *(y0p + tpitg.y + 4);
 
         float acc = 0.0f;
 
-        //for (int j = 0; j < 16; ++j) {
-        const int j = tpitg.y;
         {
-            const uchar x0v = *(x0p + j);
+            const int x0 = x0v[0] & 0x0F;
+            const int x1 = x0v[0] >> 4;
 
-            const int x0 = x0v & 0x0F;
-            const int x1 = x0v >> 4;
-
-            const float y0 = *(y0p + j);
-            const float y1 = *(y0p + j + 16);
+            const float y0 = y0v[0];
+            const float y1 = y1v[0];
 
             acc += (x0 - 8)*y0 + (x1 - 8)*y1;
         }
 
-        sum[ith] += acc * (x + i)->d;
+        {
+            const int x0 = x0v[1] & 0x0F;
+            const int x1 = x0v[1] >> 4;
+
+            const float y0 = y0v[1];
+            const float y1 = y1v[1];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[2] & 0x0F;
+            const int x1 = x0v[2] >> 4;
+
+            const float y0 = y0v[2];
+            const float y1 = y1v[2];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[3] & 0x0F;
+            const int x1 = x0v[3] >> 4;
+
+            const float y0 = y0v[3];
+            const float y1 = y1v[3];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        sum[ith] += acc*d;
     }
 
     // accumulate the sum from all threads in the threadgroup
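The host-side change above replaces the single `nth` with a 2-D threadgroup shape `nth0 x nth1` and sizes the shared reduction buffer at one float per thread (`nth0*nth1*sizeof(float)` instead of the hard-coded `16*nth`). Correspondingly, the kernel's new index `ith = tptg.y*tpitg.x + tpitg.y` is a row-major enumeration of the threadgroup, so each thread gets a distinct `sum[]` slot for any shape rather than only for a fixed 16-wide y dimension. A minimal C sketch of that invariant (hypothetical `check_dispatch` helper, not part of the patch):

```c
#include <assert.h>
#include <string.h>

// Hypothetical check (not part of the patch): every (tpitg.x, tpitg.y)
// pair maps to a unique slot in the nth0*nth1-float shared buffer,
// mirroring ith = tptg.y*tpitg.x + tpitg.y in mtl.metal.
static void check_dispatch(int nth0, int nth1) {
    char used[32*4];                       // large enough for both shapes below
    memset(used, 0, sizeof(used));
    for (int x = 0; x < nth0; ++x) {       // tpitg.x
        for (int y = 0; y < nth1; ++y) {   // tpitg.y
            const int ith = nth1*x + y;    // tptg.y == nth1
            assert(ith >= 0 && ith < nth0*nth1);
            assert(!used[ith]);            // no two threads share a slot
            used[ith] = 1;
        }
    }
}

int main(void) {
    check_dispatch(8, 4);   // Q4_0 path in this patch
    check_dispatch(32, 1);  // F16 path
    return 0;
}
```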
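On the kernel side, each thread now fetches its share of a block with one `uchar4` load (8 nibbles) and two `float4` loads, and the former one-byte-per-thread body (`j = tpitg.y` over a 16-wide threadgroup) is unrolled into four scalar steps over lanes `[0..3]`; with `tpitg.y` in `[0, 4)`, the four threads in y still cover all 16 quantized bytes of a block, while the threadgroup shrinks from 4x16 to 8x4 threads. The unrolling relies on the Q4_0 pairing where, within a block of qk = 32 values, the low nibble of byte `j` multiplies `y[j]` and the high nibble multiplies `y[j + 16]`, both offset by 8 and scaled by the block's `d`. A scalar reference for one block, assuming that layout (hypothetical, useful for checking the vectorized kernel against):

```c
#include <stdint.h>

#define QK4_0 32

// Hypothetical scalar reference (not part of the patch): dot product of
// one Q4_0 block (16 nibble bytes plus scale d) against 32 floats of y.
static float q4_0_block_dot(const uint8_t qs[QK4_0/2], float d, const float y[QK4_0]) {
    float acc = 0.0f;
    for (int j = 0; j < QK4_0/2; ++j) {
        const int x0 = (qs[j] & 0x0F) - 8;   // low nibble pairs with y[j]
        const int x1 = (qs[j] >>   4) - 8;   // high nibble pairs with y[j + 16]
        acc += x0*y[j] + x1*y[j + QK4_0/2];
    }
    return acc*d;
}
```

The block scale is also converted to `float` once up front (`const float d = (float)((x + i)->d);`) and applied after the four lane accumulations, rather than being dereferenced at the end of each block as before.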