From 167a0bbe3406909fd8ff6dd9845c2d9281a1f728 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 12:34:07 +0300 Subject: [PATCH] k_quants: WIP super-blocks with 64 weights Q4_K works on Metal and is actually slightly faster than QK_K = 256 (21.95 ms vs 24.0 ms). --- ggml-metal.metal | 65 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 03f599d80..11f390d70 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -781,6 +781,12 @@ kernel void kernel_cpy_f32_f32( static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); #endif +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif + typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants @@ -797,13 +803,19 @@ typedef struct { } block_q3_k; // 110 bytes / block +#if QK_K == 64 +typedef struct { + half4 d; // super-block scales/mins + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_k; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_k; -// 144 bytes / block +#endif typedef struct { half d; // super-block scale for quantized scales @@ -934,10 +946,12 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int i = 0; i < nb; i++) { + device const uint8_t * q = x[i].qs; + +#if QK_K == 256 const float d = x[i].d; const float min = x[i].dmin; - device const uint8_t * q = x[i].qs; device const uint8_t * scales = x[i].scales; int is = 0; @@ -949,6 +963,14 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + float4 d4 = (float4)x[i].d; + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1]; + y[l+32] = d4[2] * (q[l] >> 4) - d4[3]; + } + y += QK_K; +#endif } } @@ -1332,11 +1354,15 @@ kernel void kernel_mul_mat_q4_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 @@ -1350,11 +1376,8 @@ kernel void kernel_mul_mat_q4_k_f32( const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; - sum[ith] = 0.0f; - uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1383,6 +1406,22 @@ kernel void kernel_mul_mat_q4_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4*tpitg.x; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i * QK_K + il; + + const float4 d4 = (float4)x[i].d; + + for (int l = 0; l < 4; ++l) { + sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16]) + + d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]); + } + } +#endif sum[ith] = sumf; @@ -1593,12 +1632,14 @@ kernel void kernel_mul_mat_q6_k_f32( const float d = x[i].d; + float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) - + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) - + y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); } #endif