From af226bd26e890e41621881a41a8675a76f178f29 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 10:14:42 +0300 Subject: [PATCH] Somewhat faster f16 x f32 matrix multiply kernel --- ggml-metal.metal | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 82e1a0c7a..02db5323e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - sum[tpitg.x] = 0.0f; + uint ith = tpitg.x; + uint nth = tptg.x; - for (int i = tpitg.x; i < ne00; i += tptg.x) { - sum[tpitg.x] += (float) x[i] * (float) y[i]; + sum[ith] = 0.0f; + + for (int i = ith; i < ne00; i += nth) { + sum[ith] += (float) x[i] * (float) y[i]; } // accumulate the sum from all threads in the threadgroup threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = tptg.x/2; i > 0; i /= 2) { - if (tpitg.x < i) { - sum[tpitg.x] += sum[tpitg.x + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; } - - if (tpitg.x == 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; } + + // Original implementation. Left behind commented out for now + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = tptg.x/2; i > 0; i /= 2) { + // if (tpitg.x < i) { + // sum[tpitg.x] += sum[tpitg.x + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + // + //if (tpitg.x == 0) { + // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; + //} } kernel void kernel_alibi_f32(