From fff0e4f9d8fe67e70672e00c3c5b72b51c054511 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 10 Jun 2023 15:57:33 +0300
Subject: [PATCH] metal : still optimizing Q4_K

This commit pushes it down to 25.3 ms / token.

The crazy idea of using 6 bits for the scales is really costly on Metal:
if I remove the bit fiddling necessary to make the block scales, time goes
down to almost the Q4_0 time of 23 ms/token.

Before pushing the k-quants upstream I had a Q4_K variant that used
8-bit scales. It wasn't more accurate, used 0.125 bits more per weight,
was running slightly slower on the CPU (due to the larger model size and
being memory bound there), and the difference was entirely negligible
under CUDA. So, I decided to publish the version with 6-bit scales.

Perhaps I should reconsider and change to 8-bit scales?
---
 ggml-metal.metal | 68 ++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 54 insertions(+), 14 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 954902249..789aa5d5e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -986,42 +986,82 @@ kernel void kernel_mul_mat_q4_k_f32(
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
-    const int ir  = tid%4;     // 0...3
+    //const int ir = tid%4;    // 0...3
+    const int ir  = tid - 4*il;// 0...3
     const int n   = 4;
     const int im  = il/2;      // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in  = il%2;
+
     const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
 
     sum[ith] = 0.0f;
 
+    //uint16_t aux_scales[4];
+    //thread uint8_t * sc = (thread uint8_t *)aux_scales;
+
+    //uint32_t aux32[4];
+    //thread const uint8_t * sc = (thread const uint8_t *)aux32;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q1 = (x + i)->qs + 32*im + l0;
-        device const float   * y1 = yy + i*QK_K + 64*im + l0;
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
         device const uint8_t * q2 = q1 + 64;
+        device const float   * y1 = yy + i*QK_K + y_offset;
         device const float   * y2 = y1 + 128;
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-
         const float dall = (float)((x + i)->d);
         const float dmin = (float)((x + i)->dmin);
 
-        const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        //device const uint32_t * a = (device const uint32_t *)(x + i)->scales;
+        //aux32[0] = a[0] & 0x3f3f3f3f;  // scales for 0, 32, 64, 96
+        //aux32[1] = a[1] & 0x3f3f3f3f;  // mins for 0, 32, 64, 96
+        //aux32[2] = ((a[2] >> 0) & 0x0f0f0f0f) | ((a[0] & 0xc0c0c0c0) >> 2);  // scales for 128, 160, 192, 224
+        //aux32[3] = ((a[2] >> 4) & 0x0f0f0f0f) | ((a[1] & 0xc0c0c0c0) >> 2);  // mins for 128, 160, 192, 224
 
-        float2 s = {0.f, 0.f};
+        //aux_scales[0] = (uint16_t)(a[im+0] & kmask1);
+        //aux_scales[1] = (uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2));
+        //aux_scales[2] = (uint16_t)(a[im+2] & kmask1);
+        //aux_scales[3] = (uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2));
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        //float2 s = {0.f, 0.f};
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
-                  + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
-            s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+            ////s[0] += y1[l] * sc[0] * (q1[l] & 0xF) + y1[l+32] * sc[1] * (q1[l] >> 4)
+            ////      + y2[l] * sc[2] * (q2[l] & 0xF) + y2[l+32] * sc[3] * (q2[l] >> 4);
+            ////s[1] += y1[l] * sc[4] + y1[l+32] * sc[5] + y2[l] * sc[6] + y2[l+32] * sc[7];
+
+            ////s[0] += y1[l] * sc[2*im+0] * (q1[l] & 0xF) + y1[l+32] * sc[2*im+1] * (q1[l] >> 4)
+            ////      + y2[l] * sc[2*im+8] * (q2[l] & 0xF) + y2[l+32] * sc[2*im+9] * (q2[l] >> 4);
+            ////s[1] += y1[l] * sc[2*im+4] + y1[l+32] * sc[2*im+5] + y2[l] * sc[2*im+12] + y2[l+32] * sc[2*im+13];
+
+            //s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
+            //      + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
+            //s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
         }
-        sumf += dall * s[0] - dmin * s[1];
+        //sumf += dall * s[0] - dmin * s[1];
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+
     sum[ith] = sumf;
 
     //
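For reference: the "bit fiddling" the commit message complains about is the unpacking of the 12-byte Q4_K scales[] array, which packs 8 six-bit scales and 8 six-bit mins per 256-weight super-block. Below is a minimal, stand-alone C sketch of that unpacking, done byte-by-byte the way the CPU side does it (the helper name and the zero-filled test input are illustrative only, not part of this patch). The Metal kernel above does the equivalent on 16-bit words; kmask1/kmask2/kmask3 are presumably 0x3f3f, 0x0f0f and 0xc0c0, judging from the 32-bit masks in the commented-out variant.

#include <stdint.h>
#include <stdio.h>

// Q4_K packs 8 six-bit scales and 8 six-bit mins into 12 bytes:
//   bytes 0..3  : low 6 bits of scales 0..3, top 2 bits spill into scales 4..7
//   bytes 4..7  : low 6 bits of mins   0..3, top 2 bits spill into mins   4..7
//   bytes 8..11 : low nibble = low 4 bits of scales 4..7, high nibble = low 4 bits of mins 4..7
// Illustrative helper; the 16-bit-word version in the Metal kernel computes the
// same values with the kmask1/kmask2/kmask3 masks and a shift by 2.
static void get_scale_min(int j, const uint8_t * q, uint8_t * sc, uint8_t * mn) {
    if (j < 4) {
        *sc = q[j]   & 63;
        *mn = q[j+4] & 63;
    } else {
        *sc = (q[j+4] & 0x0F) | ((q[j-4] >> 6) << 4);
        *mn = (q[j+4] >>    4) | ((q[j]   >> 6) << 4);
    }
}

int main(void) {
    uint8_t scales[12] = {0};      // stand-in for block_q4_K.scales
    uint8_t sc, mn;
    for (int j = 0; j < 8; ++j) {  // one (scale, min) pair per 32-weight sub-block
        get_scale_min(j, scales, &sc, &mn);
        printf("sub-block %d: scale=%u min=%u\n", j, sc, mn);
    }
    return 0;
}

Each of the upper four (scale, min) pairs needs extra shift-and-mask work to stitch its top bits back together; that per-block overhead is what the commit message weighs against simply storing plain 8-bit scales.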