k_quants: WIP super-blocks with 64 weights
With QK_K = 64, Q4_K works on Metal and is actually slightly faster than with QK_K = 256 (21.95 ms vs 24.0 ms).
This commit is contained in:
parent
e1bbcfc5cb
commit
167a0bbe34
1 changed files with 53 additions and 12 deletions
|
@ -781,6 +781,12 @@ kernel void kernel_cpy_f32_f32(
|
||||||
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
|
#define K_SCALE_SIZE 12
|
||||||
|
#else
|
||||||
|
#define K_SCALE_SIZE 4
|
||||||
|
#endif
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||||
uint8_t qs[QK_K/4]; // quants
|
uint8_t qs[QK_K/4]; // quants
|
||||||
|
@ -797,13 +803,19 @@ typedef struct {
|
||||||
} block_q3_k;
|
} block_q3_k;
|
||||||
// 110 bytes / block
|
// 110 bytes / block
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
typedef struct {
|
||||||
|
half4 d; // super-block scales/mins
|
||||||
|
uint8_t qs[QK_K/2]; // 4-bit quants
|
||||||
|
} block_q4_k;
|
||||||
|
#else
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // super-block scale for quantized scales
|
half d; // super-block scale for quantized scales
|
||||||
half dmin; // super-block scale for quantized mins
|
half dmin; // super-block scale for quantized mins
|
||||||
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||||
uint8_t qs[QK_K/2]; // 4-bit quants
|
uint8_t qs[QK_K/2]; // 4-bit quants
|
||||||
} block_q4_k;
|
} block_q4_k;
|
||||||
// 144 bytes / block
|
#endif
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // super-block scale for quantized scales
|
half d; // super-block scale for quantized scales
|
||||||
|
@ -934,10 +946,12 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
|
|
||||||
|
device const uint8_t * q = x[i].qs;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
const float min = x[i].dmin;
|
const float min = x[i].dmin;
|
||||||
|
|
||||||
device const uint8_t * q = x[i].qs;
|
|
||||||
device const uint8_t * scales = x[i].scales;
|
device const uint8_t * scales = x[i].scales;
|
||||||
|
|
||||||
int is = 0;
|
int is = 0;
|
||||||
|
@ -949,6 +963,14 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
||||||
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
||||||
q += 32; is += 2;
|
q += 32; is += 2;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
float4 d4 = (float4)x[i].d;
|
||||||
|
for (int l = 0; l < 32; ++l) {
|
||||||
|
y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1];
|
||||||
|
y[l+32] = d4[2] * (q[l] >> 4) - d4[3];
|
||||||
|
}
|
||||||
|
y += QK_K;
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1332,11 +1354,15 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
|
const int nth = tptg.x*tptg.y;
|
||||||
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
const int nth = tptg.x*tptg.y;
|
float sumf = 0;
|
||||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
|
||||||
|
#if QK_K == 256
|
||||||
|
|
||||||
const int tid = tpitg.y; // 0...16
|
const int tid = tpitg.y; // 0...16
|
||||||
const int il = tid/4; // 0...3
|
const int il = tid/4; // 0...3
|
||||||
|
@ -1350,11 +1376,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
const int q_offset = 32*im + l0;
|
const int q_offset = 32*im + l0;
|
||||||
const int y_offset = 64*im + l0;
|
const int y_offset = 64*im + l0;
|
||||||
|
|
||||||
sum[ith] = 0.0f;
|
|
||||||
|
|
||||||
uchar2 sc1, sc2, sc3, sc4;
|
uchar2 sc1, sc2, sc3, sc4;
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||||
|
|
||||||
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
||||||
|
@ -1383,6 +1406,22 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
const int il = 4*tpitg.x;
|
||||||
|
|
||||||
|
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||||
|
|
||||||
|
device const uint8_t * q = x[i].qs + il;
|
||||||
|
device const float * y = yy + i * QK_K + il;
|
||||||
|
|
||||||
|
const float4 d4 = (float4)x[i].d;
|
||||||
|
|
||||||
|
for (int l = 0; l < 4; ++l) {
|
||||||
|
sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16])
|
||||||
|
+ d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
sum[ith] = sumf;
|
sum[ith] = sumf;
|
||||||
|
|
||||||
|
@ -1593,12 +1632,14 @@ kernel void kernel_mul_mat_q6_k_f32(
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32)
|
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32);
|
||||||
+ y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32)
|
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32);
|
||||||
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32)
|
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32);
|
||||||
+ y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
|
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
|
||||||
}
|
}
|
||||||
|
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue