From b42dfdcd89e1905e8bb63ac5e94cadb3908503d5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 11 Sep 2023 08:14:45 +0200 Subject: [PATCH] metal: very slightly faster TG for Q4_K --- ggml-metal.metal | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index a6a3354e9..a7f9c9b6d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -733,7 +733,10 @@ kernel void kernel_mul_mat_f16_f32_old( } for (int row = 0; row < N_F16_F32; ++row) { float all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && rb + row < ne11) { + if (rb + row >= ne11) { + break; + } + if (tiisg == 0) { dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum; } } @@ -1545,8 +1548,7 @@ kernel void kernel_mul_mat_q4_K_f32( const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[16]; - float yh[16]; + float4 yl[8]; float sumf[N_DST]={0.f}, all_sum; const int step = sizeof(block_q4_K) * nb / 2; @@ -1556,14 +1558,14 @@ kernel void kernel_mul_mat_q4_K_f32( uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f}; + for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; + float4 sumy = {0.f}; for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + yl[i] = {y4[i], y4[i+32], y4[i+128], y4[i+160]}; + sumy += yl[i]; } device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; @@ -1582,22 +1584,23 @@ kernel void kernel_mul_mat_q4_K_f32( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + + acc1[0] += yl[i+0][0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+0][1] * (q1[i/2] & 0x00F0); + acc1[2] += yl[i+0][2] * (q2[i/2] & 0x000F); + acc1[3] += yl[i+0][3] * (q2[i/2] & 0x00F0); + acc2[0] += yl[i+1][0] * (q1[i/2] & 0x0F00); + acc2[1] += yl[i+1][1] * (q1[i/2] & 0xF000); + acc2[2] += yl[i+1][2] * (q2[i/2] & 0x0F00); + acc2[3] += yl[i+1][3] * (q2[i/2] & 0xF000); + } float dall = dh[0]; float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + acc1 += acc2 / 256.f; + acc1 *= norm; + sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += step;