k_quants: WIP super-blocks with 64 weights
* We are able to pass preprocessor macros to the Metal compiler
* Q6_K works and is actually slightly more efficient than the QK_K = 256 version (25.2 ms vs 25.8 ms)
This commit is contained in:
parent
fae24afd01
commit
e1bbcfc5cb
2 changed files with 49 additions and 3 deletions
|
@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_QKK_64
|
||||||
|
MTLCompileOptions* options = [MTLCompileOptions new];
|
||||||
|
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
||||||
|
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
||||||
|
#else
|
||||||
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
||||||
|
#endif
|
||||||
if (error) {
|
if (error) {
|
||||||
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
exit(1);
|
exit(1);
|
||||||
|
|
|
@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (ith == 0) {
|
if (ith == 0) {
|
||||||
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
||||||
dst[r1*ne0 + r0] = sum[0];
|
dst[r1*ne0 + r0] = sum[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (ith == 0) {
|
if (ith == 0) {
|
||||||
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
||||||
dst[r1*ne0 + r0] = sum[0];
|
dst[r1*ne0 + r0] = sum[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -775,7 +775,11 @@ kernel void kernel_cpy_f32_f32(
|
||||||
|
|
||||||
//============================================ k-quants ======================================================
|
//============================================ k-quants ======================================================
|
||||||
|
|
||||||
|
#ifndef QK_K
|
||||||
#define QK_K 256
|
#define QK_K 256
|
||||||
|
#else
|
||||||
|
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
||||||
|
#endif
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||||
|
@ -988,6 +992,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
for (int n = 0; n < QK_K; n += 128) {
|
for (int n = 0; n < QK_K; n += 128) {
|
||||||
for (int l = 0; l < 32; ++l) {
|
for (int l = 0; l < 32; ++l) {
|
||||||
int is = l/16;
|
int is = l/16;
|
||||||
|
@ -1005,6 +1010,19 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
||||||
qh += 32;
|
qh += 32;
|
||||||
sc += 8;
|
sc += 8;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
for (int l = 0; l < 16; ++l) {
|
||||||
|
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||||
|
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||||
|
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||||
|
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||||
|
y[l+ 0] = d * sc[0] * q1;
|
||||||
|
y[l+16] = d * sc[1] * q2;
|
||||||
|
y[l+32] = d * sc[2] * q3;
|
||||||
|
y[l+48] = d * sc[3] * q4;
|
||||||
|
}
|
||||||
|
y += 64;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1528,6 +1546,9 @@ kernel void kernel_mul_mat_q6_k_f32(
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
||||||
const int iqs = 16 * tpitg.y;
|
const int iqs = 16 * tpitg.y;
|
||||||
const int ip = iqs / 128; // 0 or 1
|
const int ip = iqs / 128; // 0 or 1
|
||||||
|
@ -1540,7 +1561,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
||||||
const int q_offset_l = 64*ip + l0;
|
const int q_offset_l = 64*ip + l0;
|
||||||
const int q_offset_h = 32*ip + l0;
|
const int q_offset_h = 32*ip + l0;
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||||
|
|
||||||
device const uint8_t * ql = x[i].ql + q_offset_l;
|
device const uint8_t * ql = x[i].ql + q_offset_l;
|
||||||
|
@ -1562,6 +1582,26 @@ kernel void kernel_mul_mat_q6_k_f32(
|
||||||
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
||||||
|
|
||||||
|
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||||
|
device const float * y = yy + i * QK_K + il;
|
||||||
|
device const uint8_t * ql = x[i].ql + il;
|
||||||
|
device const uint8_t * qh = x[i].qh + il;
|
||||||
|
device const int8_t * s = x[i].scales;
|
||||||
|
|
||||||
|
const float d = x[i].d;
|
||||||
|
|
||||||
|
for (int l = 0; l < 4; ++l) {
|
||||||
|
sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32)
|
||||||
|
+ y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32)
|
||||||
|
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32)
|
||||||
|
+ y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
sum[ith] = sumf;
|
sum[ith] = sumf;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue