diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index ff1adf6df..372396047 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -52,8 +52,11 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_rms_norm;
     id<MTLComputePipelineState> pipeline_rms_norm;
 
-    id<MTLFunction>             function_mul_mat_q4_0;
-    id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
+    id<MTLFunction>             function_mul_mat_q4_0_f32;
+    id<MTLComputePipelineState> pipeline_mul_mat_q4_0_f32;
+
+    id<MTLFunction>             function_mul_mat_f16_f32;
+    id<MTLComputePipelineState> pipeline_mul_mat_f16_f32;
 
     id<MTLFunction>             function_rope;
     id<MTLComputePipelineState> pipeline_rope;
@@ -183,9 +186,13 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
         fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
 
-        ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
-        ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
-        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
+        ctx->function_mul_mat_q4_0_f32 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0_f32"];
+        ctx->pipeline_mul_mat_q4_0_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0_f32: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0_f32);
+
+        ctx->function_mul_mat_f16_f32 = [ctx->library newFunctionWithName:@"kernel_mul_mat_f16_f32"];
+        ctx->pipeline_mul_mat_f16_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_f16_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul_mat_f16_f32: %p\n", __func__, (void *) ctx->pipeline_mul_mat_f16_f32);
 
         ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
         ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
@@ -493,6 +500,8 @@ int llama_mtl_eval(
                     //const uint64_t nb1 = gf->nodes[i]->nb[1];
                     const uint64_t nb2 = gf->nodes[i]->nb[2];
 
+                    const int nth = 16;
+
                    const enum ggml_type src0t = gf->nodes[i]->src0->type;
                    const enum ggml_type src1t = gf->nodes[i]->src1->type;
                    const enum ggml_type dstt  = gf->nodes[i]->type;
@@ -505,7 +514,7 @@ int llama_mtl_eval(
                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);
 
-                    if (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) {
+                    if ((src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
                         if (encoder != nil) {
                             [encoder endEncoding];
                             encoder = nil;
@@ -528,6 +537,8 @@ int llama_mtl_eval(
                             initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                 resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
+                        // we need to do ne02 multiplications
+                        // TODO: is there a way to do this in parallel - currently very slow ..
                         for (int64_t i02 = 0; i02 < ne02; ++i02) {
                             size_t offs_src0_cur = offs_src0 + i02*nb02;
                             size_t offs_src1_cur = offs_src1 + i02*nb12;
@@ -544,8 +555,13 @@ int llama_mtl_eval(
                             encoder = [command_buffer computeCommandEncoder];
                         }
 
-                        // for Q4 x F32 we use custom kernel
-                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
+                        // use custom matrix x vector kernel
+                        switch (src0t) {
+                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; break;
+                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];  break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -555,9 +571,9 @@ int llama_mtl_eval(
                         [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
                         [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:7];
                         [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:8];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     }
                 } break;
            case GGML_OP_GET_ROWS:
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index f8446d17f..1bada42dd 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -241,7 +241,7 @@ kernel void kernel_rms_norm(
     }
 }
 
-kernel void kernel_mul_mat_q4_0(
+kernel void kernel_mul_mat_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -268,39 +268,6 @@ kernel void kernel_mul_mat_q4_0(
     sum[tpitg.x] = 0.0f;
 
     for (int i = tpitg.x; i < nb; i += tptg.x) {
-        //device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
-        //device const float4 * y0p = (device const float4 *) (y + i*qk);
-
-        //const uint4 x0 = *x0p;
-
-        //const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
-        //const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
-
-        //thread const char * x0lsb = (thread const char *) &x0l;
-        //thread const char * x0hsb = (thread const char *) &x0h;
-
-        //const float4 y00 = *(y0p + 0);
-        //const float4 y01 = *(y0p + 1);
-        //const float4 y02 = *(y0p + 2);
-        //const float4 y03 = *(y0p + 3);
-        //const float4 y04 = *(y0p + 4);
-        //const float4 y05 = *(y0p + 5);
-        //const float4 y06 = *(y0p + 6);
-        //const float4 y07 = *(y0p + 7);
-
-        //const half d = (x + i)->d;
-
-        //sum[tpitg.x] += (
-        //    (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
-        //    (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
-        //    (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
-        //    (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
-        //    (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
-        //    (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
-        //    (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
-        //    (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
-        //    ) * d;
-
         device const uchar * x0p = (device const uchar *) (x + i)->qs;
         device const float * y0p = (device const float *) (y + i*qk);
 
@@ -335,6 +302,47 @@ kernel void kernel_mul_mat_q4_0(
     }
 }
 
+kernel void kernel_mul_mat_f16_f32(
+        device const half  * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2  tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const half  * x = src0 + r0*ne00;
+    device const float * y = src1 + r1*ne10;
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device float * dst,
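Note on the routing change above: llama_mtl_eval now takes the MPSMatrixMultiplication path for F32/F16 src0 only when ne11 > 1 (matrix x matrix); with ne11 == 1 the operation is a matrix x vector product and falls through to the custom kernels selected by the switch on src0t. The program below is a minimal standalone harness for sanity-checking kernel_mul_mat_f16_f32 in isolation. It is a sketch, not part of the patch: the file name, test dimensions, all-ones inputs, and the use of newDefaultLibrary (which assumes mtl.metal has been compiled into the app's default.metallib) are illustrative assumptions.

// mul_mat_f16_f32_test.m (hypothetical file name, not part of the patch)
// build (assumed): clang -fobjc-arc mul_mat_f16_f32_test.m -framework Metal -framework Foundation -o test
#import <Metal/Metal.h>
#include <stdio.h>

int main(void) {
    const int64_t ne00 = 64;    // src0 columns == src1 length
    const int64_t ne01 = 4;     // src0 rows    == dst length
    const int64_t ne10 = ne00;
    const int64_t ne11 = 1;     // single src1 column -> matrix x vector path
    const int64_t ne0  = ne01;
    const int64_t ne1  = ne11;

    const int nth = 16;         // threads per threadgroup, as in the patch

    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    id<MTLCommandQueue> queue = [device newCommandQueue];

    // assumes mtl.metal was compiled into the default library (default.metallib)
    id<MTLLibrary> library = [device newDefaultLibrary];
    id<MTLFunction> function = [library newFunctionWithName:@"kernel_mul_mat_f16_f32"];
    id<MTLComputePipelineState> pipeline =
        [device newComputePipelineStateWithFunction:function error:nil];

    // shared-storage buffers so the CPU can fill/read them directly
    // (__fp16 is clang's storage type for Metal's half)
    id<MTLBuffer> src0 = [device newBufferWithLength:ne01*ne00*sizeof(__fp16) options:MTLResourceStorageModeShared];
    id<MTLBuffer> src1 = [device newBufferWithLength:ne11*ne10*sizeof(float)  options:MTLResourceStorageModeShared];
    id<MTLBuffer> dst  = [device newBufferWithLength:ne1 *ne0 *sizeof(float)  options:MTLResourceStorageModeShared];

    // all-ones inputs -> every dst element should equal ne00
    __fp16 * x = (__fp16 *) src0.contents;
    float  * y = (float  *) src1.contents;
    for (int64_t i = 0; i < ne01*ne00; ++i) x[i] = 1.0f;
    for (int64_t i = 0; i < ne11*ne10; ++i) y[i] = 1.0f;

    id<MTLCommandBuffer> command_buffer = [queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];

    // same argument layout as the GGML_OP_MUL_MAT branch in llama_mtl_eval
    [encoder setComputePipelineState:pipeline];
    [encoder setBuffer:src0 offset:0 atIndex:0];
    [encoder setBuffer:src1 offset:0 atIndex:1];
    [encoder setBuffer:dst  offset:0 atIndex:2];
    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5];
    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:7];
    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:8];
    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

    // one threadgroup per dst element, nth threads cooperating per dot product
    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
    [encoder endEncoding];

    [command_buffer commit];
    [command_buffer waitUntilCompleted];

    const float * result = (const float *) dst.contents;
    for (int64_t i = 0; i < ne0*ne1; ++i) {
        printf("dst[%lld] = %8.3f (expected %8.3f)\n", i, result[i], (float) ne00);
    }

    return 0;
}

Each threadgroup computes one dst element: the nth = 16 threads stride over the ne00 products, then the tree reduction in threadgroup memory halves the number of active threads per step. This is why nth should be a power of two and why the host reserves nth*sizeof(float) of threadgroup memory at index 0.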