k_quants: WIP super-blocks with 64 weights

Q2_K works on Metal and is very slightly faster
than QK_K = 256 (23.8 ms vs 24.2 ms).
This commit is contained in:
Iwan Kawrakow 2023-06-23 13:04:30 +03:00
parent 167a0bbe34
commit 6081a65527

View file

@ -863,6 +863,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
device const uint8_t * q = x[i].qs;
#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
@ -881,6 +882,19 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
}
q += 32;
}
#else
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
for (int l = 0; l < 16; ++l) {
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
}
y += QK_K;
#endif
}
}
@ -1153,6 +1167,9 @@ kernel void kernel_mul_mat_q2_k_f32(
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
float sumf = 0;
#if QK_K == 256
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
@ -1165,9 +1182,6 @@ kernel void kernel_mul_mat_q2_k_f32(
const int y_offset = 64*il + n*ir;
const int q_offset = 32*ip + n*ir;
sum[ith] = 0.0f;
float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q = x[i].qs + q_offset;
@ -1180,7 +1194,6 @@ kernel void kernel_mul_mat_q2_k_f32(
device const float * y = yy + i*QK_K + y_offset;
//float4 s = {0.f, 0.f, 0.f, 0.f};
float2 s = {0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
@ -1195,25 +1208,38 @@ kernel void kernel_mul_mat_q2_k_f32(
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
}
#else
const int il = 4 * tpitg.x;
uint32_t aux[2];
thread const uint8_t * d = (thread const uint8_t *)aux;
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
for (int i = tpitg.y; i < nb; i += tptg.y) {
device const uint8_t * q = x[i].qs + il;
device const float * y = yy + i*QK_K + il;
const float dall = (float)x[i].d;
const float dmin = (float)x[i].dmin;
device const uint32_t * a = (device const uint32_t *)x[i].scales;
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
for (int l = 0; l < 4; ++l) {
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
}
}
#endif
sum[ith] = sumf;
//int mask1 = (ith%4 == 0);
//int mask2 = (ith%16 == 0);
//threadgroup_barrier(mem_flags::mem_threadgroup);
//for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
//threadgroup_barrier(mem_flags::mem_threadgroup);
//for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
//threadgroup_barrier(mem_flags::mem_threadgroup);
//if (ith == 0) {
// for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
// dst[r1*ne0 + r0] = sum[0];
//}
//
// Accumulate the sum from all threads in the threadgroup
// This version is slightly faster than the commented out one below,
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {