k_quants: WIP super-blocks with 64 weights

Q4_K works on Metal and is actually slightly faster
than QK_K = 256 (21.95 ms vs 24.0 ms).
This commit is contained in:
Iwan Kawrakow 2023-06-23 12:34:07 +03:00
parent e1bbcfc5cb
commit 167a0bbe34

View file

@ -781,6 +781,12 @@ kernel void kernel_cpy_f32_f32(
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
#endif #endif
#if QK_K == 256
#define K_SCALE_SIZE 12
#else
#define K_SCALE_SIZE 4
#endif
typedef struct { typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants uint8_t qs[QK_K/4]; // quants
@ -797,13 +803,19 @@ typedef struct {
} block_q3_k; } block_q3_k;
// 110 bytes / block // 110 bytes / block
#if QK_K == 64
typedef struct {
half4 d; // super-block scales/mins
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_k;
#else
typedef struct { typedef struct {
half d; // super-block scale for quantized scales half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_k; } block_q4_k;
// 144 bytes / block #endif
typedef struct { typedef struct {
half d; // super-block scale for quantized scales half d; // super-block scale for quantized scales
@ -934,10 +946,12 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
device const uint8_t * q = x[i].qs;
#if QK_K == 256
const float d = x[i].d; const float d = x[i].d;
const float min = x[i].dmin; const float min = x[i].dmin;
device const uint8_t * q = x[i].qs;
device const uint8_t * scales = x[i].scales; device const uint8_t * scales = x[i].scales;
int is = 0; int is = 0;
@ -949,6 +963,14 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2; q += 32; is += 2;
} }
#else
float4 d4 = (float4)x[i].d;
for (int l = 0; l < 32; ++l) {
y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1];
y[l+32] = d4[2] * (q[l] >> 4) - d4[3];
}
y += QK_K;
#endif
} }
} }
@ -1332,11 +1354,15 @@ kernel void kernel_mul_mat_q4_k_f32(
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10; device const float * yy = (device const float *) src1 + r1*ne10;
const int nth = tptg.x*tptg.y; float sumf = 0;
const int ith = tptg.y*tpitg.x + tpitg.y;
#if QK_K == 256
const int tid = tpitg.y; // 0...16 const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3 const int il = tid/4; // 0...3
@ -1350,11 +1376,8 @@ kernel void kernel_mul_mat_q4_k_f32(
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0; const int y_offset = 64*im + l0;
sum[ith] = 0.0f;
uchar2 sc1, sc2, sc3, sc4; uchar2 sc1, sc2, sc3, sc4;
float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) { for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q1 = (x + i)->qs + q_offset; device const uint8_t * q1 = (x + i)->qs + q_offset;
@ -1383,6 +1406,22 @@ kernel void kernel_mul_mat_q4_k_f32(
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
} }
#else
const int il = 4*tpitg.x;
for (int i = tpitg.y; i < nb; i += tptg.y) {
device const uint8_t * q = x[i].qs + il;
device const float * y = yy + i * QK_K + il;
const float4 d4 = (float4)x[i].d;
for (int l = 0; l < 4; ++l) {
sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16])
+ d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]);
}
}
#endif
sum[ith] = sumf; sum[ith] = sumf;
@ -1593,12 +1632,14 @@ kernel void kernel_mul_mat_q6_k_f32(
const float d = x[i].d; const float d = x[i].d;
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32);
+ y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32);
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32);
+ y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
} }
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
} }
#endif #endif