metal : q3_K still not working
Adding a heavily commented q3_K metal kernel to explain my obviously faulty logic. Perhaps someone could spot the issue?
This commit is contained in:
parent
3b4f5e167c
commit
982c7cf5cc
1 changed files with 59 additions and 64 deletions
123
ggml-metal.metal
123
ggml-metal.metal
|
@ -634,6 +634,7 @@ typedef struct {
|
|||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
} block_q2_k;
|
||||
// 84 bytes / block
|
||||
|
||||
typedef struct {
|
||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||
|
@ -641,6 +642,7 @@ typedef struct {
|
|||
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||
half d; // super-block scale
|
||||
} block_q3_k;
|
||||
// 110 bytes / block
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
|
@ -648,6 +650,7 @@ typedef struct {
|
|||
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_k;
|
||||
// 144 bytes / block
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
|
@ -656,6 +659,7 @@ typedef struct {
|
|||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_k;
|
||||
// 176 bytes / block
|
||||
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
@ -663,6 +667,7 @@ typedef struct {
|
|||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
half d; // super-block scale
|
||||
} block_q6_k;
|
||||
// 210 bytes / block
|
||||
|
||||
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||
uchar4 r;
|
||||
|
@ -723,6 +728,7 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|||
const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
|
||||
uint32_t aux[4];
|
||||
thread const int8_t * scales = (thread const int8_t *)aux;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
|
@ -739,25 +745,13 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|||
aux[0] = (a[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
|
||||
aux[1] = (a[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
char4 scales;
|
||||
|
||||
int shift1 = 0;
|
||||
|
||||
int ia = 0;
|
||||
int is = 4;
|
||||
int is = 0;
|
||||
float dl;
|
||||
for (int n = 0; n < QK_K; n += 128) {
|
||||
|
||||
int shift = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
|
||||
const int shift2 = 2*(j/2);
|
||||
|
||||
if (is == 4) {
|
||||
scales = as_type<char4>(aux[ia++]);
|
||||
is = 0;
|
||||
}
|
||||
|
||||
dl = d_all * (scales[is++] - 32);
|
||||
for (int l = 0; l < 16; ++l) {
|
||||
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
|
||||
|
@ -772,7 +766,6 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|||
m <<= 1;
|
||||
}
|
||||
q += 32;
|
||||
shift1 += 4;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1051,29 +1044,17 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
threadgroup float * sum [[threadgroup(0)]],
|
||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
|
||||
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||
uint2 tptg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint32_t kmask1 = 0x03030303;
|
||||
const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
|
||||
uint32_t aux[2];
|
||||
|
||||
const uint8_t m1 = 1;
|
||||
const uint8_t m3 = 3;
|
||||
const int8_t m4 = 4;
|
||||
|
||||
|
@ -1088,61 +1069,75 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|||
const int nth = tptg.x*tptg.y;
|
||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||
|
||||
const int step = QK_K / tptg.y; // we expect this to be 16
|
||||
const int iqs = step * tpitg.y; // 0...240 in steps of 16
|
||||
const int ip = iqs / 128; // 0 or 1
|
||||
const int il = (iqs - 128*ip)/16; // 0...7
|
||||
const int tid = tpitg.y;
|
||||
const int il = tid/4; // 0...3 0 -> 0...63, 1 -> 64...127, 2 -> 128...191, 3 -> 192...255
|
||||
const int ip = il / 2; // 0 or 1 0 -> use 1st 32 q's (0...127), 1 -> 2nd 32 (128...255)
|
||||
const int is = il % 2; // 0 or 1 0 -> 0...63, 128...191, 1 -> 64...127, 192...255
|
||||
const int ir = tid - 4*il; // 0...3
|
||||
const int n = 4;
|
||||
const int l0 = n * il;
|
||||
const int is = l0/16;
|
||||
const uint8_t m = m1 << (4*ip);
|
||||
const uchar4 mask = {m, (uint8_t)(m << 1), (uint8_t)(m << 2), (uint8_t)(m << 3)};
|
||||
const int l0 = n * ir; // first index for this thread within a group of 32 (0, 4, 8, 12)
|
||||
// 0...31 use 1<<0, 32...63 use 1<<1, 64...95 use 1<<2, 96...128 use 1<<3, etc.
|
||||
// we process 64*il...64*il+63 -> 1st mask is 1<<(2*il), second is 1<<(2*il+1)
|
||||
// masks for high bit
|
||||
const uint8_t m = 1 << (2*il);
|
||||
const uchar2 mask = {m, (uint8_t)(m << 1)};
|
||||
|
||||
const int shift1 = 4*ip;
|
||||
const int shift2 = shift1 + 2;
|
||||
const int shift1 = 4*ip; // 1st shift for scale. must be 0 (0...127) or 4 (128...255)
|
||||
const int shift2 = 2*il; // 2nd shift for scale. 0, 2, 4, or 6
|
||||
// 1st shift for quants must be 0 in 0...31, 2 in 32...64, 4 in 64...96, 6 in 96...128, then agsin 0, 2, etc.
|
||||
const int shift3 = 4*is;
|
||||
const int shift4 = shift3 + 2;
|
||||
|
||||
//int8_t sc[4];
|
||||
const int q_offset = 32*ip + l0;
|
||||
const int y_offset = 64*il + l0;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
|
||||
// Copied from the C de-quantization code
|
||||
//aux[0] = ((a[0] >> 0) & kmask2) | (((a[2] >> 0) & kmask1) << 4);
|
||||
//aux[1] = ((a[1] >> 0) & kmask2) | (((a[2] >> 2) & kmask1) << 4);
|
||||
//aux[2] = ((a[0] >> 4) & kmask2) | (((a[2] >> 4) & kmask1) << 4);
|
||||
//aux[3] = ((a[1] >> 4) & kmask2) | (((a[2] >> 6) & kmask1) << 4);
|
||||
|
||||
//// 0....63 we need a[0] with shift=0, a[2] with shift 0
|
||||
//// 64...127 we need a[1] with shift=0, a[2] with shift 2
|
||||
////128...191 we need a[0] with shift=4, a[2] with shift 4
|
||||
////192...255 we need a[1] with shift=4, a[2] with shift 6
|
||||
//// a[is] >> (4*ip) & 0xF | a[2] >> (2*il) & 3
|
||||
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
||||
uint32_t tmp = a[2];
|
||||
aux[0] = ((a[0] >> shift1) & kmask2) | (((tmp >> shift1) & kmask1) << 4);
|
||||
aux[1] = ((a[1] >> shift1) & kmask2) | (((tmp >> shift2) & kmask1) << 4);
|
||||
const char4 sc = as_type<char4>(((a[is] >> shift1) & kmask2) | (((a[2] >> shift2) & kmask1) << 4));
|
||||
|
||||
device const uint8_t * q = x[i].qs + 32*ip + l0;
|
||||
device const uint8_t * hm = x[i].hmask + l0;
|
||||
//device const uint8_t * scales = x[i].scales;
|
||||
// Here I was thinking "what if the above is not processed correctly because x[i].scales is not 4-byte
|
||||
// aligned?". If that was the issue, using a uint16_t pointer should solve it as x[i].scales is 2-byte aligned.
|
||||
// It does not solve the problem, it just makes it run slower.
|
||||
//device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
||||
//const char2 sc1 = as_type<char2>((uint16_t)(((a[2*is+0] >> shift1) & kmask2) | (((a[4] >> shift2) & kmask1) << 4)));
|
||||
//const char2 sc2 = as_type<char2>((uint16_t)(((a[2*is+1] >> shift1) & kmask2) | (((a[5] >> shift2) & kmask1) << 4)));
|
||||
|
||||
device const float * y = yy + i * QK_K + 128*ip + l0;
|
||||
device const uint8_t * q = x[i].qs + q_offset;
|
||||
device const uint8_t * h = x[i].hmask + l0;
|
||||
|
||||
const float dall = x[i].d;
|
||||
device const float * y = yy + i * QK_K + y_offset;
|
||||
|
||||
const char4 sc1 = as_type<char4>(aux[0]);
|
||||
const char4 sc2 = as_type<char4>(aux[1]);
|
||||
|
||||
//sc[0] = ((scales[is+0] >> shift1) & 0xF) | (((scales[is+ 8] >> shift1) & m3) << 4);
|
||||
//sc[1] = ((scales[is+2] >> shift1) & 0xF) | (((scales[is+10] >> shift1) & m3) << 4);
|
||||
//sc[2] = ((scales[is+4] >> shift1) & 0xF) | (((scales[is+ 8] >> shift2) & m3) << 4);
|
||||
//sc[3] = ((scales[is+6] >> shift1) & 0xF) | (((scales[is+10] >> shift2) & m3) << 4);
|
||||
const float dall = (float)x[i].d;
|
||||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sums[0] += y[l+ 0] * ((int8_t)((q[l] >> 0) & m3) - ((hm[l] & mask[0]) ? 0 : m4));
|
||||
sums[1] += y[l+32] * ((int8_t)((q[l] >> 2) & m3) - ((hm[l] & mask[1]) ? 0 : m4));
|
||||
sums[2] += y[l+64] * ((int8_t)((q[l] >> 4) & m3) - ((hm[l] & mask[2]) ? 0 : m4));
|
||||
sums[3] += y[l+96] * ((int8_t)((q[l] >> 6) & m3) - ((hm[l] & mask[3]) ? 0 : m4));
|
||||
sums[0] += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift3) & m3) - ((h[l+ 0] & mask[0]) ? 0 : m4));
|
||||
sums[1] += y[l+16] * ((int8_t)((q[l+16] >> shift3) & m3) - ((h[l+16] & mask[0]) ? 0 : m4));
|
||||
sums[2] += y[l+32] * ((int8_t)((q[l+ 0] >> shift4) & m3) - ((h[l+ 0] & mask[1]) ? 0 : m4));
|
||||
sums[3] += y[l+48] * ((int8_t)((q[l+16] >> shift4) & m3) - ((h[l+16] & mask[1]) ? 0 : m4));
|
||||
}
|
||||
|
||||
//sumf += dall * (sums[0] * (sc[0] - 32)
|
||||
// + sums[1] * (sc[1] - 32)
|
||||
// + sums[2] * (sc[2] - 32)
|
||||
// + sums[3] * (sc[3] - 32));
|
||||
sumf += dall * (sums[0] * (sc1[is+0] - 32)
|
||||
+ sums[1] * (sc1[is+2] - 32)
|
||||
+ sums[2] * (sc2[is+0] - 32)
|
||||
+ sums[3] * (sc2[is+2] - 32));
|
||||
sumf += dall * (sums[0] * (sc[0] - 32)
|
||||
+ sums[1] * (sc[1] - 32)
|
||||
+ sums[2] * (sc[2] - 32)
|
||||
+ sums[3] * (sc[3] - 32));
|
||||
//sumf += dall * (sums[0] * (sc1[0] - 32)
|
||||
// + sums[1] * (sc1[1] - 32)
|
||||
// + sums[2] * (sc2[0] - 32)
|
||||
// + sums[3] * (sc2[1] - 32));
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue