From 92a0c17474000e46e914825660b5694d9d82ca03 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 7 Feb 2024 11:20:04 +0200 Subject: [PATCH 1/5] metal : initial working version --- ggml-metal.m | 91 ++++++++- ggml-metal.metal | 369 +++++++++++++++++++++++++++++++------ tests/test-backend-ops.cpp | 38 +++- 3 files changed, 439 insertions(+), 59 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index c1d8e2de8..6d051e8ab 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -116,6 +116,21 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -488,6 +503,21 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32, mul_mm2_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32, mul_mm2_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32, mul_mm2_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32, mul_mm2_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32, mul_mm2_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32, mul_mm2_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32, mul_mm2_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32, mul_mm2_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32, mul_mm2_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32, mul_mm2_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32, mul_mm2_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32, mul_mm2_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32, mul_mm2_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32, mul_mm2_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32, mul_mm2_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -1271,7 +1301,66 @@ static bool ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + if (src1t == GGML_TYPE_F32 && ne11 <= 8) { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32].pipeline; break; + default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + + const int nsg = 8; + + const int nsg0 = 1; + const int nsh0 = 8; + const int nsg1 = 1; + const int nsh1 = 64; + + GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4 + GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16 + GGML_ASSERT(nsh0 % 2 == 0); // dequantize in chunks of 2x8 = 16 + GGML_ASSERT(nsh1 % nsh0 == 0); + GGML_ASSERT(nsh0 >= 2*nsg1); // need enough memory to store the results in f32 + + const size_t shmem = nsg*(8*nsg0)*(8*nsh0)*(sizeof(float)/2) + (8*nsg1)*(8*nsh1)*sizeof(float); + + GGML_ASSERT(shmem <= 32*1024); + GGML_ASSERT(shmem >= nsg*(8*nsg0)*(8*nsg1)*sizeof(float)); + + [encoder setThreadgroupMemoryLength:shmem atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 8*nsg0*nsg - 1)/(8*nsg0*nsg), (ne11 + 8*nsg1 - 1)/(8*nsg1), ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && diff --git a/ggml-metal.metal b/ggml-metal.metal index efed6ad46..8926ec6bb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4650,25 +4650,28 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +void kernel_mul_mm_impl( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); @@ -4781,6 +4784,194 @@ void kernel_mul_mm_impl(device const uchar * src0, } } +#define NSG0 1 +#define NSH0 8 +#define NSG1 1 +#define NSH1 64 + +// each block_q contains 16*nl weights +template +void kernel_mul_mm2_impl( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_u8 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const uint nsg = ntg.y; // number of simdgroups + + const int64_t im = tgpig[2]; + const int64_t i11 = tgpig[1]*(8*NSG1); + const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); + + const int64_t i12 = im%ne12; + const int64_t i13 = im/ne12; + + const int64_t ne01 = ne0; + const int64_t ne11 = ne1; + + const int64_t NW = N_SIMDWIDTH; + + const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) + const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) + + const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) + const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4) + + const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float) + const int64_t T14 = T1/4; // row of src1 in shared memory in (float4) + + threadgroup half * shared = (threadgroup half *) shared_u8; + + threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); + threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); + threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); + threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); + + threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1)); + + simdgroup_half8x8 m0[NSG0]; + simdgroup_float8x8 m1[NSG1]; + simdgroup_float8x8 mr[NSG0][NSG1]; + + // zero out shared memory SH0 for src0 + for (int64_t i = tiisg; i < SH04; i += NW) { + s04[i] = 0.0h; + } + + // zero out shared memory SH1 for src1 + for (int64_t i = tiitg; i < SH14; i += nsg*NW) { + s14[i] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // initialize mr + for (int j = 0; j < NSG0; j++) { + for (int i = 0; i < NSG1; i++) { + mr[j][i] = make_filled_simdgroup_matrix(0.f); + } + } + + for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) { + // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory + { + threadgroup_barrier(mem_flags::mem_threadgroup); + + const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1); + + for (int64_t i = tiitg; i < nload; i += nsg*NW) { + const int64_t ic = i%(8*NSH1); + const int64_t ir = i/(8*NSH1); + + // TODO: use float4 + device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + + if (i00 + ic < ne00) { + s1[8*NSH1*ir + ic] = *p1; + } else { + s1[8*NSH1*ir + ic] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { + // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory + { + const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0); + + half4x4 tmp0; + + for (int64_t i = 16*tiisg; i < nload; i += 16*NW) { + const int64_t ic = i%(8*NSH0); + const int64_t ir = i/(8*NSH0); + + const int64_t icc = i00 + 8*b0*NSH0 + ic; + + const int64_t ib = (icc/(16*nl)); + const int64_t il = (icc%(16*nl))/16; + + device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib; + + dequantize_func(p0, il, tmp0); + + for (int k = 0; k < 4; k++){ + if (icc + 4*k < ne00) { + s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; + } else { + s04[(8*NSH0*ir + ic)/4 + k] = 0.0h; + } + } + } + } + + simdgroup_barrier(mem_flags::mem_none); + +#pragma unroll(NSH0) + for (int k = 0; k < NSH0; ++k) { + for (int j = 0; j < NSG0; ++j) { + simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0); + } + + for (int i = 0; i < NSG1; ++i) { + simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true); + } + + for (int j = 0; j < NSG0; ++j) { + for (int i = 0; i < NSG1; ++i) { + simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]); + } + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // write the mr to shared memory + + for (int i = 0; i < NSG1; i++) { + for (int j = 0; j < NSG0; j++) { + simdgroup_store(mr[j][i], r0 + (8*i)*(8*NSG0) + 8*j, 8*NSG0, 0, true); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * pdst = dst + im*ne1*ne0; + + for (int is = 0; is < NSG1; is++) { + const int64_t i1 = i11 + is*8; + const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0); + + for (int64_t i = tiisg; i < nstore; i += NW) { + const int64_t ic = i%(8*NSG0); + const int64_t ir = i/(8*NSG0); + + if (i1 + ir < ne1 && i01 + ic < ne0) { + pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic]; + } + } + } +} + // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids template void kernel_mul_mm_id_impl( @@ -4802,7 +4993,9 @@ void kernel_mul_mm_id_impl( constant uint & r3, threadgroup uchar * shared_memory, uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup half * sa = (threadgroup half *)(shared_memory); @@ -4907,25 +5100,28 @@ void kernel_mul_mm_id_impl( } template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +kernel void kernel_mul_mm( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mm_impl( src0, src1, @@ -4944,7 +5140,56 @@ kernel void kernel_mul_mm(device const uchar * src0, r3, shared_memory, tgpig, + ntg, tiitg, + tiisg, + sgitg); +} + +template +kernel void kernel_mul_mm2( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mm2_impl( + src0, + src1, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + shared_memory, + tgpig, + ntg, + tiitg, + tiisg, sgitg); } @@ -4979,7 +5224,9 @@ kernel void kernel_mul_mm_id( device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], + uint3 ntg[[threads_per_threadgroup]], uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; @@ -5017,7 +5264,9 @@ kernel void kernel_mul_mm_id( r3, shared_memory, tgpig, + ntg, tiitg, + tiisg, sgitg); } @@ -5082,24 +5331,40 @@ typedef void (mat_mm_t)( constant uint & r2, constant uint & r3, threadgroup uchar *, - uint3, uint, uint); + uint3, uint3, uint, uint, uint); -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm2_f32_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_f16_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm2; +template [[host_name("kernel_mul_mm2_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm2; + // // indirect matrix-matrix multiplication // @@ -5133,7 +5398,7 @@ typedef void (mat_mm_id_t)( device const uchar * src06, device const uchar * src07, threadgroup uchar *, - uint3, uint, uint); + uint3, uint3, uint, uint, uint); template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index eb06123d2..3be0fe6cb 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -480,12 +480,13 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); - //exit(1); + printf("[%s] NMSE = %.9f > %.9f", ggml_op_desc(t1), err, ud->max_err); + printf("\n"); + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); + exit(1); ud->ok = false; } return true; @@ -572,9 +573,19 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + n_runs = 256; + int n_nodes = gf->n_nodes; + for (int i = 0; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); @@ -2044,6 +2055,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } +#if 0 for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); @@ -2063,6 +2075,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } +#else + for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 1, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 2, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 3, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 4, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 5, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 6, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 7, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 8, 4096, { 1, 1}, {1, 1})); + } + } +#endif for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { From e68e32548fa1e824f0e7bfa8414b8d853efff808 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 7 Feb 2024 23:12:22 +0200 Subject: [PATCH 2/5] metal : opts --- ggml-metal.m | 2 +- ggml-metal.metal | 110 ++++++++++++++++++++++++++--------------------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d051e8ab..80cfb2e22 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute( const int nsg = 8; const int nsg0 = 1; - const int nsh0 = 8; + const int nsh0 = 16; const int nsg1 = 1; const int nsh1 = 64; diff --git a/ggml-metal.metal b/ggml-metal.metal index 8926ec6bb..41d6f78ea 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl( } #define NSG0 1 -#define NSH0 8 +#define NSH0 16 #define NSG1 1 #define NSH1 64 @@ -4815,33 +4815,34 @@ void kernel_mul_mm2_impl( uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint nsg = ntg.y; // number of simdgroups - const int64_t im = tgpig[2]; - const int64_t i11 = tgpig[1]*(8*NSG1); - const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); + const int im = tgpig[2]; + const int i11 = tgpig[1]*(8*NSG1); + const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); - const int64_t i12 = im%ne12; - const int64_t i13 = im/ne12; + const int i12 = im%ne12; + const int i13 = im/ne12; - const int64_t ne01 = ne0; - const int64_t ne11 = ne1; + const int ne01 = ne0; + const int ne11 = ne1; - const int64_t NW = N_SIMDWIDTH; + const int NW = N_SIMDWIDTH; - const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) - const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) + const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) + const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) - const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) - const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4) + const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) + const int SH14 = SH1/4; // shread memory for src1 data in (float4) - const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float) - const int64_t T14 = T1/4; // row of src1 in shared memory in (float4) + const int T1 = 8*NSH1; // row of src1 in shared memory in (float) + const int T14 = T1/4; // row of src1 in shared memory in (float4) threadgroup half * shared = (threadgroup half *) shared_u8; - threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); - threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); - threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); - threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); + threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); + threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); + threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0); + threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); + threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1)); @@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl( simdgroup_float8x8 mr[NSG0][NSG1]; // zero out shared memory SH0 for src0 - for (int64_t i = tiisg; i < SH04; i += NW) { + for (int i = tiisg; i < SH04; i += NW) { s04[i] = 0.0h; } // zero out shared memory SH1 for src1 - for (int64_t i = tiitg; i < SH14; i += nsg*NW) { + for (int i = tiitg; i < SH14; i += nsg*NW) { s14[i] = 0.0f; } @@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl( } } - for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) { + for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) { // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory { threadgroup_barrier(mem_flags::mem_threadgroup); - const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1); + const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1); - for (int64_t i = tiitg; i < nload; i += nsg*NW) { - const int64_t ic = i%(8*NSH1); - const int64_t ir = i/(8*NSH1); + const size_t offs0 = im*nb12; - // TODO: use float4 - device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) { + const int ic = i%(8*NSH1); + const int ir = i/(8*NSH1); - if (i00 + ic < ne00) { - s1[8*NSH1*ir + ic] = *p1; + device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + + if (i00 + ic + 4 <= ne00) { + s14[(8*NSH1*ir + ic)/4] = *p1; } else { - s1[8*NSH1*ir + ic] = 0.0f; + for (int k = 0; i00 + ic + k < ne00; k++){ + s1[8*NSH1*ir + ic + k] = (*p1)[k]; + } } } @@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl( for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory { - const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0); + const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0); half4x4 tmp0; - for (int64_t i = 16*tiisg; i < nload; i += 16*NW) { - const int64_t ic = i%(8*NSH0); - const int64_t ir = i/(8*NSH0); + const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02; - const int64_t icc = i00 + 8*b0*NSH0 + ic; + for (int i = 16*tiisg; i < nload; i += 16*NW) { + const int ic = i%(8*NSH0); + const int ir = i/(8*NSH0); - const int64_t ib = (icc/(16*nl)); - const int64_t il = (icc%(16*nl))/16; + const int icc = i00 + 8*b0*NSH0 + ic; - device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib; + const int ib = (icc/(16*nl)); + const int il = (icc%(16*nl))/16; + + device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib; dequantize_func(p0, il, tmp0); - for (int k = 0; k < 4; k++){ - if (icc + 4*k < ne00) { - s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; - } else { - s04[(8*NSH0*ir + ic)/4 + k] = 0.0h; + if (icc + 16 <= ne00) { + s016[(8*NSH0*ir + ic)/16] = tmp0; + } else { + for (int k = 0; k < 4; k++){ + if (icc + 4*k <= ne00) { + s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; + } else { + for (int p = 0; icc + 4*k + p < ne00; p++) { + s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p]; + } + } } } } @@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl( device float * pdst = dst + im*ne1*ne0; for (int is = 0; is < NSG1; is++) { - const int64_t i1 = i11 + is*8; - const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0); + const int i1 = i11 + is*8; + const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0); - for (int64_t i = tiisg; i < nstore; i += NW) { - const int64_t ic = i%(8*NSG0); - const int64_t ir = i/(8*NSG0); + for (int i = tiisg; i < nstore; i += NW) { + const int ic = i%(8*NSG0); + const int ir = i/(8*NSG0); if (i1 + ir < ne1 && i01 + ic < ne0) { pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic]; From 845876d0124224a19650c9e6553f4d6b3ad0ab43 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 8 Feb 2024 13:26:50 +0200 Subject: [PATCH 3/5] metal : works with ne00 % 4 == 0 --- ggml-metal.m | 2 +- ggml-metal.metal | 23 ++++++++++++++++------- tests/test-backend-ops.cpp | 24 ++++++++++++++---------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 80cfb2e22..831d2c93a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1348,7 +1348,7 @@ static bool ggml_metal_graph_compute( const int nsh1 = 64; GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4 - GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16 + //GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16 GGML_ASSERT(nsh0 % 2 == 0); // dequantize in chunks of 2x8 = 16 GGML_ASSERT(nsh1 % nsh0 == 0); GGML_ASSERT(nsh0 >= 2*nsg1); // need enough memory to store the results in f32 diff --git a/ggml-metal.metal b/ggml-metal.metal index 41d6f78ea..f795f5386 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4872,8 +4872,6 @@ void kernel_mul_mm2_impl( for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) { // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory { - threadgroup_barrier(mem_flags::mem_threadgroup); - const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1); const size_t offs0 = im*nb12; @@ -4884,11 +4882,17 @@ void kernel_mul_mm2_impl( device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + //float4 tmp0 = *p1; + //tmp0[0] = 1; tmp0[1] = 1; tmp0[2] = 1; tmp0[3] = 1; + if (i00 + ic + 4 <= ne00) { s14[(8*NSH1*ir + ic)/4] = *p1; } else { - for (int k = 0; i00 + ic + k < ne00; k++){ - s1[8*NSH1*ir + ic + k] = (*p1)[k]; + s14[(8*NSH1*ir + ic)/4] = 0.0f; + for (int k = 0; k < 4; k++){ + if (i00 + ic + k < ne00) { + s1[8*NSH1*ir + ic + k] = (*p1)[k]; + } } } } @@ -4918,11 +4922,16 @@ void kernel_mul_mm2_impl( dequantize_func(p0, il, tmp0); + //for (int z = 0; z < 16; z++) { + // tmp0[z/4][z%4] = 1; + //} + if (icc + 16 <= ne00) { s016[(8*NSH0*ir + ic)/16] = tmp0; } else { + s016[(8*NSH0*ir + ic)/16] = half4x4(0.0h); for (int k = 0; k < 4; k++){ - if (icc + 4*k <= ne00) { + if (icc + 4*k < ne00) { s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; } else { for (int p = 0; icc + 4*k + p < ne00; p++) { @@ -4953,9 +4962,9 @@ void kernel_mul_mm2_impl( } } } - } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); + } // write the mr to shared memory diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3be0fe6cb..ae323a384 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2076,16 +2076,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } #else - for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { - for (ggml_type type_b : {GGML_TYPE_F32}) { - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 1, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 2, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 3, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 4, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 5, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 6, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 7, 4096, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 8, 4096, { 1, 1}, {1, 1})); + for (int r0 = 0; r0 < 32; ++r0) { + for (int c0 = 0; c0 < 4096; c0 += 512) { + for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 1, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 2, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 3, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 4, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 5, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 6, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 7, 64 + c0, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 8, 64 + c0, { 1, 1}, {1, 1})); + } + } } } #endif From e8b00e29415df61145bc1d76b63b118da003f34f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 8 Feb 2024 16:39:38 +0200 Subject: [PATCH 4/5] metal : fix NSG1 > 1 --- ggml-metal.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index f795f5386..dba9935af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4980,7 +4980,7 @@ void kernel_mul_mm2_impl( for (int is = 0; is < NSG1; is++) { const int i1 = i11 + is*8; - const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0); + const int nstore = MIN(8, ne1 - i1) * (8*NSG0); for (int i = tiisg; i < nstore; i += NW) { const int ic = i%(8*NSG0); From 5a668ea00078a318ee11c7f18b068973ac43a5cf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 12 Feb 2024 19:21:57 +0200 Subject: [PATCH 5/5] metal : trying bs = 512 performance (wip) --- ggml-metal.m | 12 ++++++------ ggml-metal.metal | 32 ++++++++++++++++++++++++-------- tests/test-backend-ops.cpp | 15 ++++++++++++++- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 831d2c93a..47cf991ca 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1301,7 +1301,7 @@ static bool ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (src1t == GGML_TYPE_F32 && ne11 <= 8) { + if (src1t == GGML_TYPE_F32) { id pipeline = nil; switch (src0->type) { @@ -1340,12 +1340,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - const int nsg = 8; + const int nsg = 4; - const int nsg0 = 1; - const int nsh0 = 16; - const int nsg1 = 1; - const int nsh1 = 64; + const int nsg0 = 4; + const int nsh0 = 4; + const int nsg1 = 2; + const int nsh1 = 4; GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4 //GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16 diff --git a/ggml-metal.metal b/ggml-metal.metal index dba9935af..74e74da19 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl( } } -#define NSG0 1 -#define NSH0 16 -#define NSG1 1 -#define NSH1 64 +#define NSG0 4 +#define NSH0 4 +#define NSG1 2 +#define NSH1 4 // each block_q contains 16*nl weights template @@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl( } for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory { const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1); @@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl( } } } - - threadgroup_barrier(mem_flags::mem_threadgroup); } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory { @@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl( simdgroup_barrier(mem_flags::mem_none); +#if 0 #pragma unroll(NSH0) for (int k = 0; k < NSH0; ++k) { for (int j = 0; j < NSG0; ++j) { @@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl( } } } - } +#else +#pragma unroll(NSH0) + for (int k = 0; k < NSH0; ++k) { + for (int i = 0; i < NSG1; ++i) { + simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + for (int j = 0; j < NSG0; ++j) { + simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0); + for (int i = 0; i < NSG1; ++i) { + simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]); + } + } + } +#endif + } } // write the mr to shared memory diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ae323a384..a15856eae 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2075,7 +2075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } -#else +#elif 0 for (int r0 = 0; r0 < 32; ++r0) { for (int c0 = 0; c0 < 4096; c0 += 512) { for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { @@ -2092,6 +2092,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#elif 1 + for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + } + } #endif for (ggml_type type_a : all_types) {