k_quants: forgot to add the Metal changes in last commit

This commit is contained in:
Iwan Kawrakow 2023-06-24 15:47:27 +03:00
parent ce19b965f0
commit 4f61506929

View file

@ -808,7 +808,8 @@ typedef struct {
#if QK_K == 64 #if QK_K == 64
typedef struct { typedef struct {
half4 d; // super-block scales/mins half d[2]; // super-block scales/mins
uint8_t scales[2];
uint8_t qs[QK_K/2]; // 4-bit quants uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K; } block_q4_K;
#else #else
@ -996,7 +997,6 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int nb = k / QK_K; const int nb = k / QK_K;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
device const uint8_t * q = x[i].qs; device const uint8_t * q = x[i].qs;
@ -1017,10 +1017,16 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i
q += 32; is += 2; q += 32; is += 2;
} }
#else #else
float4 d4 = (float4)x[i].d; device const uint8_t * s = x[i].scales;
device const half2 * dh = (device const half2 *)x[i].d;
const float2 d = (float2)dh[0];
const float d1 = d[0] * (s[0] & 0xF);
const float d2 = d[0] * (s[1] & 0xF);
const float m1 = d[1] * (s[0] >> 4);
const float m2 = d[1] * (s[1] >> 4);
for (int l = 0; l < 32; ++l) { for (int l = 0; l < 32; ++l) {
y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1]; y[l+ 0] = d1 * (q[l] & 0xF) - m1;
y[l+32] = d4[2] * (q[l] >> 4) - d4[3]; y[l+32] = d2 * (q[l] >> 4) - m2;
} }
y += QK_K; y += QK_K;
#endif #endif
@ -1529,6 +1535,9 @@ kernel void kernel_mul_mat_q4_K_f32(
} }
#else #else
uint16_t aux16[2];
thread const uint8_t * scales = (thread const uint8_t *)aux16;
const int il = 4*tpitg.x; const int il = 4*tpitg.x;
for (int i = tpitg.y; i < nb; i += tptg.y) { for (int i = tpitg.y; i < nb; i += tptg.y) {
@ -1536,11 +1545,16 @@ kernel void kernel_mul_mat_q4_K_f32(
device const uint8_t * q = x[i].qs + il; device const uint8_t * q = x[i].qs + il;
device const float * y = yy + i * QK_K + il; device const float * y = yy + i * QK_K + il;
const float4 d4 = (float4)x[i].d; const float d = (float)x[i].d[0];
const float m = (float)x[i].d[1];
device const uint16_t * a = (device const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16]) sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+ d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]); + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
} }
} }
#endif #endif