From d90b5981d0cf48e89743cbd30814951950b04a99 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 10 Sep 2023 15:26:06 +0200
Subject: [PATCH] 12% faster PP for Falcon

---
 ggml-metal.m     |  31 ++++++++----
 ggml-metal.metal | 119 +++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 120 insertions(+), 30 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index cadbfe04a..1c467a6da 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -902,22 +902,28 @@ void ggml_metal_graph_compute(
                     } else {
                         int nth0 = 32;
                         int nth1 = 1;
-                        int nrows = 1;
+                        //int nrows = 1;
+                        int nx = 1, ny = 1;
 
                         // use custom matrix x vector kernel
                         switch (src0t) {
                             case GGML_TYPE_F16:
                                 {
-                                    nth0 = 32;
-                                    nth1 = 1;
-                                    if (ne11 * ne12 < 4) {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
-                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                    //[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                    //nth0 = 32;
+                                    //nth1 = 1;
+                                    //if (ne11 * ne12 < 4) {
+                                    //    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                    if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
-                                        nrows = ne11;
+                                        nx = ne01;
+                                        ny = 1;
+                                        nth0 = 32;
                                     } else {
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                                        nrows = 4;
+                                        nth0 = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
+                                        nx = (ne01 + nth0 - 1)/nth0;
+                                        ny = ne11;
                                     }
                                 } break;
                             case GGML_TYPE_Q4_0:
@@ -1038,8 +1044,13 @@ void ggml_metal_graph_compute(
                         else if (src0t == GGML_TYPE_Q6_K) {
                             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else {
-                            int64_t ny = (ne11 + nrows - 1)/nrows;
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            ////printf("f16xf32: %d x %d x %d, %d x %d x %d -> %d\n",(int)ne00,(int)ne01,(int)ne02,
+                            ////        (int)ne10,(int)ne11,(int)ne12,nrows);
+                            //int64_t ny = (ne11 + nrows - 1)/nrows;
+                            //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            //[encoder dispatchThreadgroups:MTLSizeMake(ne10*ne11*ne12, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            //int n = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
+                            [encoder dispatchThreadgroups:MTLSizeMake(nx, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, 1, 1)];
                         }
                     }
                 } break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c40a71a6b..a6a3354e9 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -610,9 +610,68 @@ kernel void kernel_mul_mat_f16_f32_1row(
 }
 
-#define N_F16_F32 4
-
+#define N_F16_F32 8
+
 kernel void kernel_mul_mat_f16_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntptg[[threads_per_threadgroup]],
+        uint tiitg[[thread_index_in_threadgroup]]) {
+
+    // :MTLSizeMake(ne01, ne11, ne12)
+    const int64_t r0 = tgpig.x * ntptg.x + tiitg;
+    if (r0 >= ne0) {
+        return;
+    }
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    //const int64_t im = tiitg/(ne10*ne11);
+    //const int64_t r1 = (tiitg - im*ne10*ne11)/ne10;
+    //const int64_t r0 = tiitg - im*ne10*ne11 - r1*ne10;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    if (ne00 < 16) {
+        float sumf = 0;
+        for (int i = 0; i < ne00; ++i) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sumf;
+    }
+    else {
+        float sumf = 0;
+        device const half4 * x4 = (device const half4 *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = 0; i < ne00/4; ++i) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+        for (int i = 4*(ne00/4); i < ne00; ++i) {
+            sumf += (float) x[i] * y[i];
+        }
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sumf;
+
+    }
+}
+
+kernel void kernel_mul_mat_f16_f32_old(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -639,7 +698,7 @@ kernel void kernel_mul_mat_f16_f32(
 
     device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
-    if (ne00 < 128) {
+    if (ne00 < 64) { //128) {
         for (int row = 0; row < N_F16_F32; ++row) {
             int r1 = rb + row;
             if (r1 >= ne11) {
@@ -659,27 +718,47 @@
             }
         }
     } else {
+        const int ix = tiisg/N_F16_F32;
+        const int iy = tiisg%N_F16_F32;
+        const int r1 = rb + iy < ne11 ? rb + iy : ne11-1;
+        float sumf[N_F16_F32] = {0.f};
         device const half4 * x4 = (device const half4 *)x;
+        device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = ix; i < ne00/4; i += 32/N_F16_F32) {
+            for (int k = 0; k < 4; ++k) sumf[iy] += (float) x4[i][k] * y4[i][k];
+        }
+        for (int i = 4*(ne00/4)+ix; i < ne00; i += 32/N_F16_F32) {
+            sumf[iy] += (float) x[i] * y[i];
+        }
         for (int row = 0; row < N_F16_F32; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
-            device const float4 * y4 = (device const float4 *) y;
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            float all_sum = simd_sum(sumf[row]);
+            if (tiisg == 0 && rb + row < ne11) {
+                dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum;
             }
         }
+
+        //device const half4 * x4 = (device const half4 *)x;
+        //for (int row = 0; row < N_F16_F32; ++row) {
+        //    int r1 = rb + row;
+        //    if (r1 >= ne11) {
+        //        break;
+        //    }
+
+        //    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+        //    device const float4 * y4 = (device const float4 *) y;
+
+        //    float sumf = 0;
+        //    for (int i = tiisg; i < ne00/4; i += 32) {
+        //        for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        //    }
+
+        //    float all_sum = simd_sum(sumf);
+        //    if (tiisg == 0) {
+        //        for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+        //        dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        //    }
+        //}
     }
 }
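
Reading the dispatch change in ggml-metal.m above: for the generic f16 x f32 path, each thread now produces one output row, so nth0 is clamped to the largest power of two (at most 32) that fits ne01, the grid is nx = ceil(ne01/nth0) threadgroups along x, ny = ne11 threadgroups along y (one per src1 row), and ne12 along the batch axis. Below is a minimal host-side sketch of that arithmetic in plain C; the function name, the example shapes, and the main() driver are illustrative assumptions and not part of the patch.

#include <stdio.h>

/* Illustrative sketch only (not part of the patch): mirrors the threadgroup
 * grid computed in ggml-metal.m for the generic f16 x f32 path, where each
 * thread handles one row of src0. */
static void dispatch_dims(long ne01, long ne11, long ne12,
                          int *nth0, long *nx, long *ny, long *nz) {
    int t = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 :
            ne01 >=  4 ?  4 : ne01 >=  2 ?  2 : 1;   /* threads per threadgroup along x */
    *nth0 = t;
    *nx = (ne01 + t - 1)/t;   /* threadgroups along x: ceil(ne01/nth0), one thread per src0 row */
    *ny = ne11;               /* one threadgroup row per src1 row */
    *nz = ne12;               /* batch dimension */
}

int main(void) {
    int nth0; long nx, ny, nz;
    dispatch_dims(4544, 512, 1, &nth0, &nx, &ny, &nz);   /* Falcon-7B-like shapes, illustrative */
    printf("grid %ld x %ld x %ld, %d threads per threadgroup\n", nx, ny, nz, nth0);
    return 0;
}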