From 4560acced5aeec7af14ef507bd88608778920e4d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 8 Sep 2023 16:34:07 +0200 Subject: [PATCH] Another faster f16 x f32 matrix multiply kernel --- ggml-metal.m | 10 +++++++++- ggml-metal.metal | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 81edc5046..fd24f1412 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -78,6 +78,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); @@ -223,6 +224,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); @@ -290,6 +292,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); @@ -872,6 +875,7 @@ void ggml_metal_graph_compute( } else { int nth0 = 32; int nth1 = 1; + int nrows = 1; // use custom matrix x vector kernel switch (src0t) { @@ -881,8 +885,12 @@ void ggml_metal_graph_compute( nth1 = 1; if (ne11 * ne12 < 4) { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + nrows = ne11; } else { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + nrows = 4; } } break; case GGML_TYPE_Q4_0: @@ -1003,7 +1011,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + 3)/4; + int64_t ny = (ne11 + nrows - 1)/nrows; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } diff --git a/ggml-metal.metal b/ggml-metal.metal index 2da0a1459..242c26889 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -630,6 +630,50 @@ kernel void kernel_mul_mat_f16_f32( } } +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mat_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + for (int r1 = 0; r1 < nrows; ++r1) { + + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + kernel void kernel_alibi_f32( device const float * src0, device float * dst,