k_quants: WIP super-blocks with 64 weights

The QK_K = 64 variant of Q3_K works on Metal and is slightly
faster than the QK_K = 256 version (26.6 ms vs 28.3 ms).
Iwan Kawrakow 2023-06-23 13:34:02 +03:00
parent 6081a65527
commit ff83e32c6a


@@ -798,10 +798,13 @@ typedef struct {
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
-    half d;                    // super-block scale
+#if QK_K == 64
+    int8_t  scales[K_SCALE_SIZE];
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;                    // super-block scale
 } block_q3_k;
-// 110 bytes / block
 
 #if QK_K == 64
 typedef struct {
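For reference, the super-block footprint implied by this layout can be checked host-side. A minimal C sketch (not part of the diff), assuming a 2-byte half, K_SCALE_SIZE = 12 for QK_K = 256 (the k_quants value) and K_SCALE_SIZE = 4 for QK_K = 64 (an assumption; the scales[0..3] accesses in the dequantization code below need at least 4):

#include <stdint.h>
#include <stdio.h>

#ifndef QK_K
#define QK_K 64                   // build with -DQK_K=256 for the default layout
#endif
#if QK_K == 256
#define K_SCALE_SIZE 12
#else
#define K_SCALE_SIZE 4
#endif

typedef uint16_t half;            // 2-byte stand-in for the Metal half type

typedef struct {
    uint8_t hmask[QK_K/8];        // quants - high bit
    uint8_t qs[QK_K/4];           // quants - low 2 bits
#if QK_K == 64
    int8_t  scales[K_SCALE_SIZE];
#else
    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
    half d;                       // super-block scale
} block_q3_k;

int main(void) {
    printf("QK_K = %d: %zu bytes per super-block (%.4f bits per weight)\n",
           QK_K, sizeof(block_q3_k), 8.0 * (double) sizeof(block_q3_k) / QK_K);
    return 0;
}

With QK_K = 256 this gives the 110 bytes per block of the "// 110 bytes / block" comment (3.4375 bits per weight); under the assumed K_SCALE_SIZE, QK_K = 64 gives 30 bytes per block, i.e. 3.75 bits per weight.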
@@ -903,6 +906,8 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
+
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
@@ -948,8 +953,34 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
             }
             q += 32;
         }
     }
+#else
+    for (int i = 0; i < nb; i++) {
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q  = x[i].qs;
+        device const uint8_t * hm = x[i].hmask;
+
+        const float d1 = d_all * x[i].scales[0];
+        const float d2 = d_all * x[i].scales[1];
+        const float d3 = d_all * x[i].scales[2];
+        const float d4 = d_all * x[i].scales[3];
+
+        for (int l = 0; l < 8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+#endif
 }
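The QK_K == 64 branch added here spells out how each 3-bit quant is rebuilt: the low 2 bits come from qs, the third bit from hmask, and subtracting 4 when the high bit is clear recentres the value into [-4, 3]. A minimal C sketch of that identity (decode_one is a hypothetical helper, not part of the kernel):

#include <stdint.h>

// (low2 | high << 2) - 4 equals low2 - (high ? 0 : 4), the form used above.
static inline float decode_one(uint8_t q_byte, int shift,
                               uint8_t h_byte, uint8_t h_bit, float scale) {
    const int low2 = (q_byte >> shift) & 3;     // 2 low bits from qs
    const int high = (h_byte & h_bit) ? 1 : 0;  // high bit from hmask
    return scale * (float)((low2 | (high << 2)) - 4);
}

Under that reading, y[l+ 0] above is decode_one(q[l+0], 0, hm[l], 0x01, d1), y[l+ 8] is decode_one(q[l+8], 0, hm[l], 0x02, d1), and so on for the remaining six outputs of the inner loop.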
@@ -1286,6 +1317,8 @@ kernel void kernel_mul_mat_q3_k_f32(
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+#if QK_K == 256
+
     const int tid = tpitg.y;      // expecting 16
     const int ip  = tid/8;        // 0 or 1
     const int il  = tid/2 - 4*ip; // 0...3
@@ -1339,6 +1372,39 @@ kernel void kernel_mul_mat_q3_k_f32(
     //sum[ith] = sumf;
     sum[ith] = sumf1 - 32.f*sumf2;
+#else
+    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
+
+    float sumf = 0;
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].hmask + in;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * x[i].scales[0];
+        const float d2 = d_all * x[i].scales[1];
+        const float d3 = d_all * x[i].scales[2];
+        const float d4 = d_all * x[i].scales[3];
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hm = h[l] >> im;
+            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        }
+
+    }
+    sum[ith] = sumf;
+#endif
 
     //
     // Accumulate the sum from all threads in the threadgroup
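In the QK_K == 64 branch of kernel_mul_mat_q3_k_f32 added above, each thread column handles four consecutive quant bytes of a block while thread rows stride over blocks; the per-thread partial sums are then reduced across the threadgroup as before. A rough C sketch of that index mapping, assuming tptg.x == 4 as the "0, 4, 8, 12" comment suggests (illustrative only):

#include <stdio.h>

int main(void) {
    for (int x = 0; x < 4; ++x) {   // tpitg.x
        const int il = 4 * x;       // first quant byte handled by this thread
        const int im = il / 8;      // shift applied to the hmask byte
        const int in = il % 8;      // offset into hmask
        for (int l = 0; l < 4; ++l) {
            printf("tpitg.x=%d l=%d: q[%2d], hmask[%d] >> %d, y[%2d] y[%2d] y[%2d] y[%2d]\n",
                   x, l, il + l, in + l, im,
                   il + l, il + l + 16, il + l + 32, il + l + 48);
        }
    }
    return 0;
}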
@@ -1371,10 +1437,6 @@ kernel void kernel_mul_mat_q4_k_f32(
                        uint2 tpitg[[thread_position_in_threadgroup]],
                        uint2 tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -1390,6 +1452,10 @@ kernel void kernel_mul_mat_q4_k_f32(
 #if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1660,10 +1726,10 @@ kernel void kernel_mul_mat_q6_k_f32(
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < 4; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32);
-            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32);
-            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32);
-            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
+            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
         sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
     }
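This hunk swaps the literal bit masks in kernel_mul_mat_q6_k_f32 for the named constants kmask1..kmask4, whose definitions are not shown here. Their QK_K == 256 values follow directly from the literals they replace; a sketch of the presumed definitions:

#include <stdint.h>

// Presumed QK_K == 256 values, read off the literals replaced above. Each
// mask selects the 2-bit field of qh[l] carrying the upper bits of one
// group of 6-bit quants.
static const uint8_t kmask1 = 0x03;   // pairs with ql[l+ 0] & 0xF
static const uint8_t kmask2 = 0x0c;   // pairs with ql[l+16] & 0xF
static const uint8_t kmask3 = 0x30;   // pairs with ql[l+ 0] >> 4
static const uint8_t kmask4 = 0xc0;   // pairs with ql[l+16] >> 4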