From d28b07ca7cecf9f0ec86e051a615df42aa12ae6d Mon Sep 17 00:00:00 2001 From: Matteo Boschini Date: Mon, 31 Jul 2023 14:41:23 +0200 Subject: [PATCH] Extend kernel_mul_mat_f16_f32 to handle gqa broadcast --- ggml-metal.m | 35 +++++++++++++------------------- ggml-metal.metal | 53 +++--------------------------------------------- 2 files changed, 17 insertions(+), 71 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index c079e7324..076788836 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -65,7 +65,6 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); @@ -183,7 +182,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); @@ -720,8 +718,7 @@ void ggml_metal_graph_compute( // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 GGML_ASSERT(ne00 == ne10); - int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64; - GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step); + // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && @@ -775,15 +772,9 @@ void ggml_metal_graph_compute( switch (src0t) { case GGML_TYPE_F16: { - GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step); - nth0 = 64; nth1 = 1; - if (llama_2_70_gqa_step) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8]; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; - } + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; } break; case GGML_TYPE_Q4_0: { @@ -860,16 +851,18 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { diff --git a/ggml-metal.metal b/ggml-metal.metal index a6c6fc1c8..8d26b5ec2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32( device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant int64_t & ne02, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -529,56 +531,7 @@ kernel void kernel_mul_mat_f16_f32( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - sum[tpitg.x] = 0.0f; - - for (int i = tpitg.x; i < ne00; i += tptg.x) { - sum[tpitg.x] += (float) x[i] * (float) y[i]; - } - - // accumulate the sum from all threads in the threadgroup - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = tptg.x/2; i > 0; i /= 2) { - if (tpitg.x < i) { - sum[tpitg.x] += sum[tpitg.x + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - if (tpitg.x == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - } -} - -kernel void kernel_mul_mat_f16_f32_gqa8( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpig[[thread_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); sum[tpitg.x] = 0.0f;