diff --git a/ggml-metal.m b/ggml-metal.m
index 79902c9a8..3320e9a64 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -87,6 +87,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -543,6 +544,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,          mul_mv_f16_f32,          ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,     mul_mv_f16_f32_1row,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,       mul_mv_f16_f32_l4,       ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE, mul_mv_f16_f32_l4_large, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,         mul_mv_q4_0_f32,         ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,         mul_mv_q4_1_f32,         ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,         mul_mv_q5_0_f32,         ctx->support_simdgroup_reduction);
@@ -1637,6 +1639,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

                         id<MTLComputePipelineState> pipeline = nil;
+                        bool is_large = false;

                         // use custom matrix x vector kernel
                         switch (src0t) {
@@ -1654,7 +1657,12 @@ static enum ggml_status ggml_metal_graph_compute(
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+                                        if (ne01 > 128) {
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
+                                            is_large = true;
+                                        } else {
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+                                        }
                                         nrows = ne11;
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
@@ -1844,7 +1852,11 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 const int64_t ny = (ne11 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                if (is_large) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                             }
                         }
                     } break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index e2796fd60..18032835e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1645,6 +1645,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
     }
 }

+kernel void kernel_mul_mv_f16_f32_l4_large(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t base_r0 = tgpig.x*32; // each threadgroup covers a block of 32 rows of src0
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    for (int j = 0; j < 32; ++j) {
+        const int64_t r0 = base_r0 + j;
+        if (r0 >= ne01) {
+            break; // guard the last block when ne01 is not a multiple of 32
+        }
+
+        const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+        device const half4 * x4 = (device const half4 *) (src0 + offset0);
+
+        for (int r1 = 0; r1 < nrows; ++r1) {
+            device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+            // each thread accumulates a strided slice of the dot product in a register
+            float sumf = 0.0f;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            // reduce the per-thread partial sums across the simdgroup (one simdgroup per threadgroup, so no barriers are needed)
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
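
A note on the dispatch arithmetic: the _L4_LARGE path launches one 32-thread threadgroup (a single simdgroup) per block of 32 consecutive rows, so the grid needs (ne01 + 31)/32 threadgroups along x, with the kernel-side r0 >= ne01 guard discarding the overhang in the last block. The standalone C sketch below checks that this grid/guard pair visits every row exactly once; the helper name rows_covered is illustrative only and is not part of the patch or of ggml.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// mirrors the host-side grid computation and the kernel-side row loop
static int64_t rows_covered(int64_t ne01) {
    const int64_t n_tg = (ne01 + 31)/32;     // threadgroups along x, as dispatched
    int64_t covered = 0;
    for (int64_t tg = 0; tg < n_tg; ++tg) {
        const int64_t base_r0 = tg*32;       // matches base_r0 = tgpig.x*32
        for (int j = 0; j < 32; ++j) {
            const int64_t r0 = base_r0 + j;
            if (r0 >= ne01) break;           // matches the guard in the kernel
            ++covered;
        }
    }
    return covered;
}

int main(void) {
    // every row must be visited exactly once, multiple of 32 or not
    const int64_t sizes[] = {129, 256, 1000, 4096, 4097};
    for (size_t i = 0; i < sizeof(sizes)/sizeof(sizes[0]); ++i) {
        assert(rows_covered(sizes[i]) == sizes[i]);
        printf("ne01 = %5lld: ok\n", (long long) sizes[i]);
    }
    return 0;
}

Whether ne01 > 128 is the best threshold for switching to the blocked kernel is an empirical question; the value above simply mirrors the selection logic in the patch.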