k_quants: WIP super-blocks with 64 weights
Q5_K works on Metal and is slightly faster than QK_K = 256 (23.7 ms vs 26.3 ms).
This commit is contained in:
parent
ff83e32c6a
commit
285eeb1531
1 changed files with 60 additions and 5 deletions
|
@ -820,6 +820,13 @@ typedef struct {
|
|||
} block_q4_k;
|
||||
#endif
|
||||
|
||||
#if QK_K == 64
|
||||
typedef struct {
|
||||
half4 d; // super-block scales/mins
|
||||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_k;
|
||||
#else
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
|
@ -828,6 +835,7 @@ typedef struct {
|
|||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_k;
|
||||
// 176 bytes / block
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
@ -1024,6 +1032,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
#if QK_K == 256
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
const float d = (float)(x[i].d);
|
||||
|
@ -1044,6 +1053,27 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|||
u1 <<= 2; u2 <<= 2;
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
const float4 d = (float4)x[i].d;
|
||||
|
||||
device const uint8_t * ql = x[i].qs;
|
||||
device const uint8_t * qh = x[i].qh;
|
||||
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1];
|
||||
y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1];
|
||||
y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1];
|
||||
y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1];
|
||||
y[l+32] = d[2] * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - d[3];
|
||||
y[l+40] = d[2] * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - d[3];
|
||||
y[l+48] = d[2] * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - d[3];
|
||||
y[l+56] = d[2] * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - d[3];
|
||||
}
|
||||
y += QK_K;
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
|
@ -1562,10 +1592,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|||
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||
uint2 tptg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
const uint16_t kmask3 = 0xc0c0;
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
|
@ -1577,6 +1603,14 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|||
const int nth = tptg.x*tptg.y;
|
||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
#if QK_K == 256
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
const uint16_t kmask3 = 0xc0c0;
|
||||
|
||||
const int tid = tpitg.y; // 0...16
|
||||
const int il = tid/4; // 0...3
|
||||
const int ir = tid - 4*il;// 0...3
|
||||
|
@ -1596,7 +1630,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|||
|
||||
uchar2 sc1, sc2, sc3, sc4;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
|
||||
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
||||
|
@ -1628,6 +1661,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|||
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
||||
|
||||
}
|
||||
#else
|
||||
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
||||
const int im = il/8; // 0, 0, 1, 1
|
||||
const int in = il%8; // 0, 4, 0, 4
|
||||
|
||||
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||
|
||||
device const uint8_t * q = x[i].qs + il;
|
||||
device const uint8_t * h = x[i].qh + in;
|
||||
device const float * y = yy + i*QK_K + il;
|
||||
|
||||
const float4 d = (float4)x[i].d;
|
||||
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t hl = h[l] >> im;
|
||||
sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1])
|
||||
+ y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1])
|
||||
+ y[l+32] * (d[2] * ((q[l+ 0] >> 4) + (hl & 0x10 ? 16 : 0)) - d[3])
|
||||
+ y[l+48] * (d[2] * ((q[l+16] >> 4) + (hl & 0x40 ? 16 : 0)) - d[3]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
sum[ith] = sumf;
|
||||
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue