From 285eeb15314de97f96283419804a8f8f92a293ec Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 14:11:47 +0300 Subject: [PATCH] k_quants: WIP super-blocks with 64 weights Q5_K works on Metal and is slightly faster than QK_K = 256 (23.7 ms vs 26.3 ms). --- ggml-metal.metal | 65 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ccd98956f..42c3a0412 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -820,6 +820,13 @@ typedef struct { } block_q4_k; #endif +#if QK_K == 64 +typedef struct { + half4 d; // super-block scales/mins + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_k; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins @@ -828,6 +835,7 @@ typedef struct { uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_k; // 176 bytes / block +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -1024,6 +1032,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 for (int i = 0; i < nb; i++) { const float d = (float)(x[i].d); @@ -1044,6 +1053,27 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i u1 <<= 2; u2 <<= 2; } } +#else + for (int i = 0; i < nb; i++) { + + const float4 d = (float4)x[i].d; + + device const uint8_t * ql = x[i].qs; + device const uint8_t * qh = x[i].qh; + + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1]; + y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1]; + y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1]; + y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1]; + y[l+32] = d[2] * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - d[3]; + y[l+40] = d[2] * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - d[3]; + y[l+48] = d[2] * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - d[3]; + y[l+56] = d[2] * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - d[3]; + } + y += QK_K; + } +#endif } @@ -1562,10 +1592,6 @@ kernel void kernel_mul_mat_q5_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -1577,6 +1603,14 @@ kernel void kernel_mul_mat_q5_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1596,7 +1630,6 @@ kernel void kernel_mul_mat_q5_k_f32( uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1628,6 +1661,28 @@ kernel void kernel_mul_mat_q5_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const float * y = yy + i*QK_K + il; + + const float4 d = (float4)x[i].d; + + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1]) + + y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1]) + + y[l+32] * (d[2] * ((q[l+ 0] >> 4) + (hl & 0x10 ? 16 : 0)) - d[3]) + + y[l+48] * (d[2] * ((q[l+16] >> 4) + (hl & 0x40 ? 16 : 0)) - d[3]); + } + } +#endif sum[ith] = sumf; //