From b3d55bcc726781b4826ebda19c2d8c839a7b8839 Mon Sep 17 00:00:00 2001 From: Alexander Komarov Date: Fri, 24 May 2024 11:52:13 -0700 Subject: [PATCH] replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large for vectors larger than 128 elements. --- ggml-metal.metal | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/ggml-metal.metal b/ggml-metal.metal index 8ff70d7a7..b864639ff 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4( } } +kernel void kernel_mul_mv_f16_f32_l4_large( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t base_r0 = tgpig.x*32; + const int64_t im = tgpig.z; + threadgroup float partial_sums[32]; // Shared memory for partial sums for each SIMD group + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + for (int j = 0; j < 32; ++j) { + const int64_t r0 = base_r0 + j; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + device const half4 * x4 = (device const half4 *) (src0 + offset0); + + partial_sums[tiisg] = 0.0f; + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) partial_sums[tiisg] += (float) x4[i][k] * y4[i][k]; + } + + // Barrier to ensure all threads have written their partial sums + threadgroup_barrier(mem_flags::mem_threadgroup); + float sumf = simd_sum(partial_sums[tiisg]); + // Barrier to ensure reduction is complete before writing the result + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = sumf; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +} + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y));