k_quants: WIP super-blocks with 64 weights
Q2_K works on Metal and is very slightly faster than QK_K = 256 (23.8 ms vs 24.2 ms).
This commit is contained in:
parent
167a0bbe34
commit
6081a65527
1 changed files with 45 additions and 19 deletions
|
@ -863,6 +863,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
||||||
|
|
||||||
device const uint8_t * q = x[i].qs;
|
device const uint8_t * q = x[i].qs;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
int is = 0;
|
int is = 0;
|
||||||
float dl, ml;
|
float dl, ml;
|
||||||
for (int n = 0; n < QK_K; n += 128) {
|
for (int n = 0; n < QK_K; n += 128) {
|
||||||
|
@ -881,6 +882,19 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
||||||
}
|
}
|
||||||
q += 32;
|
q += 32;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
||||||
|
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
||||||
|
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
||||||
|
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
||||||
|
for (int l = 0; l < 16; ++l) {
|
||||||
|
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
||||||
|
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
||||||
|
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
||||||
|
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
||||||
|
}
|
||||||
|
y += QK_K;
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1153,6 +1167,9 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
const int tid = tpitg.y; // 0...16
|
const int tid = tpitg.y; // 0...16
|
||||||
const int il = tid/4; // 0...3
|
const int il = tid/4; // 0...3
|
||||||
const int ir = tid%4; // 0...3
|
const int ir = tid%4; // 0...3
|
||||||
|
@ -1165,9 +1182,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
const int y_offset = 64*il + n*ir;
|
const int y_offset = 64*il + n*ir;
|
||||||
const int q_offset = 32*ip + n*ir;
|
const int q_offset = 32*ip + n*ir;
|
||||||
|
|
||||||
sum[ith] = 0.0f;
|
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||||
|
|
||||||
device const uint8_t * q = x[i].qs + q_offset;
|
device const uint8_t * q = x[i].qs + q_offset;
|
||||||
|
@ -1180,7 +1194,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
|
|
||||||
device const float * y = yy + i*QK_K + y_offset;
|
device const float * y = yy + i*QK_K + y_offset;
|
||||||
|
|
||||||
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
|
||||||
float2 s = {0.f, 0.f};
|
float2 s = {0.f, 0.f};
|
||||||
float smin = 0;
|
float smin = 0;
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
|
@ -1195,25 +1208,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
const int il = 4 * tpitg.x;
|
||||||
|
|
||||||
|
uint32_t aux[2];
|
||||||
|
thread const uint8_t * d = (thread const uint8_t *)aux;
|
||||||
|
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
||||||
|
|
||||||
|
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||||
|
|
||||||
|
device const uint8_t * q = x[i].qs + il;
|
||||||
|
device const float * y = yy + i*QK_K + il;
|
||||||
|
|
||||||
|
const float dall = (float)x[i].d;
|
||||||
|
const float dmin = (float)x[i].dmin;
|
||||||
|
|
||||||
|
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
||||||
|
aux[0] = a[0] & 0x0f0f0f0f;
|
||||||
|
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
||||||
|
|
||||||
|
for (int l = 0; l < 4; ++l) {
|
||||||
|
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
|
||||||
|
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
|
||||||
|
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
|
||||||
|
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
sum[ith] = sumf;
|
sum[ith] = sumf;
|
||||||
|
|
||||||
//int mask1 = (ith%4 == 0);
|
|
||||||
//int mask2 = (ith%16 == 0);
|
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//if (ith == 0) {
|
|
||||||
// for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
|
||||||
// dst[r1*ne0 + r0] = sum[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Accumulate the sum from all threads in the threadgroup
|
// Accumulate the sum from all threads in the threadgroup
|
||||||
// This version is slightly faster than the commented out one below,
|
|
||||||
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
|
||||||
//
|
//
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (ith%4 == 0) {
|
if (ith%4 == 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue