diff --git a/ggml-metal.metal b/ggml-metal.metal index 7185b7f2c..0a170531b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -808,7 +808,8 @@ typedef struct { #if QK_K == 64 typedef struct { - half4 d; // super-block scales/mins + half d[2]; // super-block scales/mins + uint8_t scales[2]; uint8_t qs[QK_K/2]; // 4-bit quants } block_q4_K; #else @@ -996,7 +997,6 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i assert(k % QK_K == 0); const int nb = k / QK_K; - for (int i = 0; i < nb; i++) { device const uint8_t * q = x[i].qs; @@ -1017,10 +1017,16 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i q += 32; is += 2; } #else - float4 d4 = (float4)x[i].d; + device const uint8_t * s = x[i].scales; + device const half2 * dh = (device const half2 *)x[i].d; + const float2 d = (float2)dh[0]; + const float d1 = d[0] * (s[0] & 0xF); + const float d2 = d[0] * (s[1] & 0xF); + const float m1 = d[1] * (s[0] >> 4); + const float m2 = d[1] * (s[1] >> 4); for (int l = 0; l < 32; ++l) { - y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1]; - y[l+32] = d4[2] * (q[l] >> 4) - d4[3]; + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; } y += QK_K; #endif @@ -1529,6 +1535,9 @@ kernel void kernel_mul_mat_q4_K_f32( } #else + uint16_t aux16[2]; + thread const uint8_t * scales = (thread const uint8_t *)aux16; + const int il = 4*tpitg.x; for (int i = tpitg.y; i < nb; i += tptg.y) { @@ -1536,11 +1545,16 @@ kernel void kernel_mul_mat_q4_K_f32( device const uint8_t * q = x[i].qs + il; device const float * y = yy + i * QK_K + il; - const float4 d4 = (float4)x[i].d; + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + + device const uint16_t * a = (device const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; for (int l = 0; l < 4; ++l) { - sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16]) - + d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]); + sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) + + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); } } #endif