From 5f2e4bd8ba31495df223eea47071efc29d6b7ce1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 20 Jul 2023 16:33:15 +0300 Subject: [PATCH] Another Q5_K speedup --- ggml-metal.metal | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index f89608eda..f71e8f33b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1656,22 +1656,23 @@ kernel void kernel_mul_mat_q5_K_f32( device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - float yl[8], yh[8]; float sumf[2]={0.f}; const int step = sizeof(block_q5_K) * nb; #if QK_K == 256 +# + float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int im = tid/8; - const int ir = tid%8; - const int n = 4; + const int tid = tiisg/4; + const int ix = tiisg%4; + const int im = tid/4; + const int ir = tid%4; + const int n = 8; const int l0 = n*ir; const int q_offset = 32*im + l0; @@ -1687,7 +1688,7 @@ kernel void kernel_mul_mat_q5_K_f32( device const float * y1 = yy + ix*QK_K + y_offset; - for (int i = ix; i < nb; i += 2) { + for (int i = ix; i < nb; i += 4) { device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * qh = x[i].qh + l0; @@ -1696,11 +1697,11 @@ kernel void kernel_mul_mat_q5_K_f32( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { + for (int l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+4] = y1[l+32]; sumy[1] += yl[l+4]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+4] = y2[l+32]; sumy[3] += yh[l+4]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } for (int row = 0; row < 2; ++row) { @@ -1716,9 +1717,9 @@ kernel void kernel_mul_mat_q5_K_f32( for (int l = 0; l < n; ++l) { uint8_t h = qh[l]; acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); - acc[1] += yl[l+4] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); + acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); - acc[3] += yh[l+4] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); + acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); } const float dall = dh[0]; const float dmin = dh[1]; @@ -1732,10 +1733,12 @@ kernel void kernel_mul_mat_q5_K_f32( } - y1 += 2 * QK_K; + y1 += 4 * QK_K; } #else + float yl[8], yh[8]; + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 const int ix = tiisg%8; const int im = il/8; // 0, 0, 1, 1 @@ -1747,10 +1750,10 @@ kernel void kernel_mul_mat_q5_K_f32( float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; //sumy[0] += yl[l+0]; - yl[l+4] = y[l+16]; //sumy[1] += yl[l+4]; - yh[l+0] = y[l+32]; //sumy[2] += yh[l+0]; - yh[l+4] = y[l+48]; //sumy[3] += yh[l+4]; + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; } device const half * dh = &x[i].d;