From 5e2f67fe009a97ae11c5a3ea8418c819f118c931 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 10 Jun 2023 12:04:13 +0300
Subject: [PATCH] metal : small improvement for Q4_K

---
 ggml-metal.metal | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 0909c7b30..954902249 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -971,7 +971,7 @@ kernel void kernel_mul_mat_q4_k_f32(

     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0x0303;
+    const uint16_t kmask3 = 0xc0c0;

     const int nb = ne00/QK_K;

@@ -991,16 +991,15 @@ kernel void kernel_mul_mat_q4_k_f32(

     const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
+    const int l0 = n*(2*ir + in);

     sum[ith] = 0.0f;

-    //uchar2 sc1, sc2;
-
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {

-        device const uint8_t * q1 = (x + i)->qs + 32*im + n*(2*ir + in);
-        device const float * y1 = yy + i*QK_K + 64*im + n*(2*ir + in);
+        device const uint8_t * q1 = (x + i)->qs + 32*im + l0;
+        device const float * y1 = yy + i*QK_K + 64*im + l0;
         device const uint8_t * q2 = q1 + 64;
         device const float * y2 = y1 + 128;

@@ -1011,21 +1010,16 @@ kernel void kernel_mul_mat_q4_k_f32(

         const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
         const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | (((a[im+0] >> 6) & kmask3) << 4)));
-        const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | (((a[im+2] >> 6) & kmask3) << 4)));
+        const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));

-        float4 s1 = {0.f, 0.f, 0.f, 0.f};
-        float4 s2 = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            s1[0] += y1[l+ 0] * (q1[l] & 0xF); s1[1] += y1[l+ 0];
-            s1[2] += y1[l+32] * (q1[l] >> 4); s1[3] += y1[l+32];
-            s2[0] += y2[l+ 0] * (q2[l] & 0xF); s2[1] += y2[l+ 0];
-            s2[2] += y2[l+32] * (q2[l] >> 4); s2[3] += y2[l+32];
+            s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
+                  + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
+            s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
         }
-        sumf += dall * (s1[0] * sc1[0] + s1[2] * sc1[1]
-                      + s2[0] * sc3[0] + s2[2] * sc3[1])
-              - dmin * (s1[1] * sc2[0] + s1[3] * sc2[1]
-                      + s2[1] * sc4[0] + s2[3] * sc4[1]);
+        sumf += dall * s[0] - dmin * s[1];
     }

     sum[ith] = sumf;