From ae58ac7dd40fa0e29e4fc0472e2a3ea3f118b808 Mon Sep 17 00:00:00 2001
From: Matteo Boschini
Date: Mon, 31 Jul 2023 00:02:04 +0200
Subject: [PATCH] Added gqa8 kernel to allow llama-2-70B on metal

---
 ggml-metal.m     | 17 ++++++++++++-----
 ggml-metal.metal | 50 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 74a6bff40..a7a6fd428 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -65,6 +65,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
@@ -182,6 +183,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
@@ -718,7 +720,8 @@ void ggml_metal_graph_compute(
                             // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne02 == ne12);
+                            int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64;
+                            GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
 
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
@@ -749,8 +752,8 @@ void ggml_metal_graph_compute(
                                 // we need to do ne02 multiplications
                                 // TODO: is there a way to do this in parallel - currently very slow ..
                                 // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
                                     size_t offs_src1_cur = offs_src1 + i02*nb12;
                                     size_t offs_dst_cur  = offs_dst  + i02*nb2;
 
@@ -772,11 +775,15 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            GGML_ASSERT(ne02 == ne12);
+                                            GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
 
                                             nth0 = 64;
                                             nth1 = 1;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            if (llama_2_70_gqa_step) {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
+                                            } else {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            }
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 696b33ce7..a6c6fc1c8 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -552,6 +552,56 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_mul_mat_f16_f32_gqa8(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpig[[thread_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 tptg[[threads_per_threadgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    }
+}
+
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device float * dst,
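
Reader's note on the patch (not part of the commit): the gqa8 kernel relies on llama-2-70B's grouped-query-attention layout, where ne12 = 64 query heads share ne02 = 8 key/value heads, so query head `im` reads the shared KV head `im/8`; the host-side loop expresses the same mapping generically as `i02/(ne12/ne02)`. A minimal C sketch of that index mapping follows (illustrative only; the function and variable names are assumptions for the sketch, not identifiers from llama.cpp or this patch):

    #include <assert.h>
    #include <stdio.h>

    /* Illustrative only: map a query head to the KV head it shares under
     * grouped-query attention, as the gqa8 kernel does with im/8 and the
     * host loop does with i02/(ne12/ne02). */
    static int kv_head_for_query_head(int q_head, int n_q_heads, int n_kv_heads) {
        assert(n_q_heads % n_kv_heads == 0);
        return q_head / (n_q_heads / n_kv_heads);
    }

    int main(void) {
        /* llama-2-70B: 64 query heads, 8 KV heads -> groups of 8 */
        for (int im = 0; im < 64; ++im) {
            printf("query head %2d -> kv head %d\n",
                   im, kv_head_for_query_head(im, 64, 8));
        }
        return 0;
    }

Hard-coding the divisor to 8 in the Metal kernel keeps the change small, at the cost of supporting only the 64/8 head configuration that the host-side `llama_2_70_gqa_step` check guards.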