From 677ee9ff5926335693ca74775d87063da09a40f9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Dec 2024 11:08:16 +0200 Subject: [PATCH] metal : add rest of types ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 100 +++++++-- ggml/src/ggml-metal/ggml-metal.metal | 318 ++++++++++++++++++++++----- tests/test-backend-ops.cpp | 16 +- 3 files changed, 360 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d7ee7320c..0aec92a26 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -199,6 +199,22 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, @@ -747,6 +763,22 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); @@ -1978,17 +2010,28 @@ static void ggml_metal_encode_node( // to the matrix-vector kernel int ne11_mm_min = 4; - if ((src0t == GGML_TYPE_F16 || // TODO: helper function - src0t == GGML_TYPE_Q4_0 || - src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || - src0t == GGML_TYPE_Q8_0 - ) && - src1t == GGML_TYPE_F32 && - (ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8 - (ne11 >= 2 && ne11 <= 8)) { - + if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && + ( + ( + ( + src0t == GGML_TYPE_F16 || // TODO: helper function + src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 12) + ) + ) + ) { // TODO: determine the optimal parameters based on grid utilization const int nsg = 2; // TODO: or 4? const int nxpsg = ne11 < 3 ? 16 : 8; @@ -2010,9 +2053,6 @@ static void ggml_metal_encode_node( r1ptg = 5; break; }; - assert(nxpsg >= 8); - assert(nxpsg%8 == 0); - id pipeline = nil; switch (src0->type) { @@ -2064,6 +2104,38 @@ static void ggml_metal_encode_node( case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; default: GGML_ABORT("not implemented"); } break; + case GGML_TYPE_Q4_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q6_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_IQ4_NL: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; default: GGML_ABORT("not implemented"); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 50f33b6c6..3d15897f4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -128,7 +128,7 @@ void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & r } template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { +void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 3); const float d = xb->d; const float md = -16.h * xb->d; @@ -161,14 +161,36 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg template void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { - // TODO: implement - float4x4 tmp; - dequantize_q5_0(xb, il/4, tmp); - reg = tmp[il%4]; + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + md; + reg[2*ii + 1] = d * x1 + md; + } } template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { +void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 4); const float d = xb->d; const float m = xb->m; @@ -201,10 +223,32 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg template void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { - // TODO: implement - float4x4 tmp; - dequantize_q5_1(xb, il/4, tmp); - reg = tmp[il%4]; + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + m; + reg[2*ii + 1] = d * x1 + m; + } } template @@ -285,7 +329,7 @@ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q } template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { +void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { device const uchar * q = xb->qs; short is = (il/4) * 2; @@ -297,7 +341,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float dl = d * sc[0]; const float ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; + const ushort mask = il < 2 ? 0x0F : 0xF0; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - ml; } @@ -530,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 } } +template +void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; + reg[0] = d * kvalues_iq4nl_f[q8[0]]; + reg[1] = d * kvalues_iq4nl_f[q8[1]]; + reg[2] = d * kvalues_iq4nl_f[q8[2]]; + reg[3] = d * kvalues_iq4nl_f[q8[3]]; +} + template void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 @@ -1813,8 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template -void kernel_mul_mv_ext_q_f32_impl( +template +void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, device const char * src1, @@ -1840,7 +1897,7 @@ void kernel_mul_mv_ext_q_f32_impl( const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + (tx/8)*bp32 : (device const q_t *) src0; + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; device const float4 * y4[r1ptg]; @@ -1850,23 +1907,34 @@ void kernel_mul_mv_ext_q_f32_impl( float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; - for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { + short cch = tx%chpb; + + for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { + float4 lx[chpt]; + #pragma unroll(chpt) for (short ch = 0; ch < chpt; ++ch) { - float4 lx; + deq_t4(xq, cch, lx[ch]); - deq_t4(xq + (4*ch*nxpsg)/(32/bp32), tx%8, lx); - -#pragma unroll(r1ptg) - for (short ir1 = 0; ir1 < r1ptg; ++ir1) { - sumf[ir1] += dot(lx, y4[ir1][ch*nxpsg]); + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; } } - xq += ((4*chpt)*nxpsg)/(32/bp32); +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); + } + } + +#pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { - y4[ir1] += ((4*chpt)*nxpsg)/4; + y4[ir1] += chpt*nxpsg; } } @@ -1901,8 +1969,111 @@ void kernel_mul_mv_ext_q_f32_impl( } } -template -kernel void kernel_mul_mv_ext_q_f32_disp( +template +void kernel_mul_mv_ext_q4x4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 1; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4x4 * y4x4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; + + for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { + float4x4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4x4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += + dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + + dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + + dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + + dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] += chpt*nxpsg; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +template +kernel void kernel_mul_mv_ext_q4_f32_disp( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, device const char * src1, @@ -1911,43 +2082,82 @@ kernel void kernel_mul_mv_ext_q_f32_disp( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { switch (args.nxpsg) { - case 8: kernel_mul_mv_ext_q_f32_impl<8, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q_f32_impl<16, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q_f32_impl<32, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; } } -typedef decltype(kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q_f32_t; +template +kernel void kernel_mul_mv_ext_q4x4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, half4, 8, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, half4, 8, dequantize_f16_t4>; +typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; +typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_0, 1, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_0, 1, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_1, 1, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_1, 1, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_0, 1, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_0, 1, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_0, 1, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_0, 1, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_1, 1, dequantize_q5_1_t4>; -template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_1, 1, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q8_0, 1, dequantize_q8_0_t4>; -template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q8_0, 1, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; + +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; + +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; + +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; #define N_MV_T_T 4 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5aef6c9d1..7bd1cb5e0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3573,12 +3573,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); for (int i = 1; i < 64; ++i) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); } #if 1