k_quants: WIP super-blocks with 64 weights

Q5_K works on Metal and is slightly faster
than QK_K = 256 (23.7 ms vs 26.3 ms).
This commit is contained in:
Iwan Kawrakow 2023-06-23 14:11:47 +03:00
parent ff83e32c6a
commit 285eeb1531

View file

@ -820,6 +820,13 @@ typedef struct {
} block_q4_k;
#endif
#if QK_K == 64
typedef struct {
half4 d; // super-block scales/mins
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_k;
#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
@ -828,6 +835,7 @@ typedef struct {
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_k;
// 176 bytes / block
#endif
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
@ -1024,6 +1032,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
assert(k % QK_K == 0);
const int nb = k / QK_K;
#if QK_K == 256
for (int i = 0; i < nb; i++) {
const float d = (float)(x[i].d);
@ -1044,6 +1053,27 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
u1 <<= 2; u2 <<= 2;
}
}
#else
for (int i = 0; i < nb; i++) {
const float4 d = (float4)x[i].d;
device const uint8_t * ql = x[i].qs;
device const uint8_t * qh = x[i].qh;
for (int l = 0; l < 8; ++l) {
y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1];
y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1];
y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1];
y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1];
y[l+32] = d[2] * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - d[3];
y[l+40] = d[2] * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - d[3];
y[l+48] = d[2] * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - d[3];
y[l+56] = d[2] * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - d[3];
}
y += QK_K;
}
#endif
}
@ -1562,10 +1592,6 @@ kernel void kernel_mul_mat_q5_k_f32(
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
@ -1577,6 +1603,14 @@ kernel void kernel_mul_mat_q5_k_f32(
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
float sumf = 0;
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
@ -1596,7 +1630,6 @@ kernel void kernel_mul_mat_q5_k_f32(
uchar2 sc1, sc2, sc3, sc4;
float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q1 = (x + i)->qs + q_offset;
@ -1628,6 +1661,28 @@ kernel void kernel_mul_mat_q5_k_f32(
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
}
#else
const int il = 4 * tpitg.x; // 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
const int in = il%8; // 0, 4, 0, 4
for (int i = tpitg.y; i < nb; i += tptg.y) {
device const uint8_t * q = x[i].qs + il;
device const uint8_t * h = x[i].qh + in;
device const float * y = yy + i*QK_K + il;
const float4 d = (float4)x[i].d;
for (int l = 0; l < 4; ++l) {
const uint8_t hl = h[l] >> im;
sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1])
+ y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1])
+ y[l+32] * (d[2] * ((q[l+ 0] >> 4) + (hl & 0x10 ? 16 : 0)) - d[3])
+ y[l+48] * (d[2] * ((q[l+16] >> 4) + (hl & 0x40 ? 16 : 0)) - d[3]);
}
}
#endif
sum[ith] = sumf;
//