From 35cd10c173d362c243809345ed460b6e61f108e4 Mon Sep 17 00:00:00 2001 From: lshzh-ww Date: Wed, 30 Aug 2023 23:46:42 -0400 Subject: [PATCH] metal: yet another MUL mat-vec template --- ggml-metal.m | 5 +++- ggml-metal.metal | 69 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f01cafe10..fc656fb79 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -864,7 +864,10 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; - [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0]; + // only for k-quants we use threadgroup memory + if (ggml_blck_size(src0t) >= 64){ + [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0]; + } [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; } else { switch (src0->type) { diff --git a/ggml-metal.metal b/ggml-metal.metal index afc8977d6..355b0f380 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1320,6 +1320,67 @@ kernel void kernel_mat_mv(device const void * src0, } } +template class quant_dri> +kernel void kernel_mat_mv_no_tg_mem(device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uint * shared_memory[[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nb = ne00/(nl * 16); + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int ix = tiisg / nl; + const int il = tiisg % nl; + const int first_row = (r0 * nsg) * nr + sgitg; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0) + ix; + const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8; + + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + offset1; + + float4x4 yl; // src1 vector cache + float sumf[nr] = {0.f}; + + quant_dri dequan_worker; + dequan_worker.init(il); + + // each thread in a SIMD group deals with 16 dequantized weights. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / nl) { + yl[0] = *((device const float4 *)y); + yl[1] = *((device const float4 *)y + 1); + yl[2] = *((device const float4 *)y + n_shift/4); + yl[3] = *((device const float4 *)y + n_shift/4 + 1); + + dequan_worker.inner_product_pre(il, yl); + #pragma unroll(nr) + for (int row = 0; row < nr; row++) { + float sum_temp = 0.f; + dequan_worker.inner_product(x + 2 * nb * row, il, yl, sum_temp); + sumf[row] += sum_temp; + } + x += N_SIMDWIDTH / nl; + y += N_SIMDWIDTH * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + nsg * row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot; + } + } +} + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A #define BLOCK_SIZE_K 32 @@ -1487,10 +1548,10 @@ typedef void (mat_mv_t)(device const void *, device const float *, device float #define N_DST 4 #define N_SIMDGROUP 2 -template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv; -template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv; -template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv; -template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv; +template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv_no_tg_mem; +template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem; +template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem; +template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem; template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv; template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv; template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv;