diff --git a/ggml-metal.metal b/ggml-metal.metal
index f094a1d40..a09680b61 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -8,14 +8,14 @@ using namespace metal;
 #define QR4_0 2
 typedef struct {
     half d;                // delta
-    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+    uint16_t qs[QK4_0 / 4]; // nibbles / quants
 } block_q4_0;
 
 #define QK4_1 32
 typedef struct {
     half d;                // delta
     half m;                // min
-    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+    uint16_t qs[QK4_1 / 4]; // nibbles / quants
 } block_q4_1;
 
 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -28,12 +28,16 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     for (int i = 0; i < nb; i++) {
         const half d = x[i].d;
 
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F) - 8;
-            const int x1 = (x[i].qs[j] >> 4) - 8;
+        for (int j = 0; j < qk/4; ++j) {
+            const int x0 = (x[i].qs[j] & 0x000F) - 8;
+            const int x1 = ((x[i].qs[j] & 0x00F0) >> 4) - 8;
+            const int x2 = ((x[i].qs[j] & 0x0F00) >> 8) - 8;
+            const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
 
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
+            y[i*qk + 2 * j + 0       ] = x0*d;
+            y[i*qk + 2 * j + qk/2    ] = x1*d;
+            y[i*qk + 2 * j + 1       ] = x2*d;
+            y[i*qk + 2 * j + 1 + qk/2] = x3*d;
         }
     }
 }
@@ -49,12 +53,16 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
         const half d = x[i].d;
         const half m = x[i].m;
 
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F);
-            const int x1 = (x[i].qs[j] >> 4);
+        for (int j = 0; j < qk/4; ++j) {
+            const int x0 = (x[i].qs[j] & 0x000F);
+            const int x1 = ((x[i].qs[j] & 0x00F0) >> 4);
+            const int x2 = ((x[i].qs[j] & 0x0F00) >> 8);
+            const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
 
-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
+            y[i*qk + 2 * j + 0       ] = x0*d + m;
+            y[i*qk + 2 * j + qk/2    ] = x1*d + m;
+            y[i*qk + 2 * j + 1       ] = x2*d + m;
+            y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
         }
     }
 }
@@ -401,6 +409,10 @@ kernel void kernel_mul_mat_q4_0_f32(
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
         sumy *= (-8.f);
+        // we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this
+        for (int i = 0; i < 32; i++) {
+            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+        }
 
         for (int row = 0; row < N_DST; row++) {
             // prefetch next x block
@@ -409,8 +421,9 @@ kernel void kernel_mul_mat_q4_0_f32(
             // calculate
             float d = qb_curr.d;
             float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+            for (int i = 0; i < 16; i+=2) {
+                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
             }
             sumf[row] += d * acc;
             qb_curr = qb_next;
@@ -432,6 +445,9 @@ kernel void kernel_mul_mat_q4_0_f32(
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
         sumy *= (-8.f);
+        for (int i = 0; i < 32; i++) {
+            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+        }
 
         for (int row = 0; row < N_DST; row++) {
             // prefetch next x block
@@ -440,8 +456,9 @@ kernel void kernel_mul_mat_q4_0_f32(
             // calculate
             float d = qb_curr.d;
             float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+            for (int i = 0; i < 16; i+=2) {
+                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
             }
             if (tiisg < nb % N_SIMDWIDTH) {
                 sumf[row] += d * acc;
@@ -487,6 +504,10 @@ kernel void kernel_mul_mat_q4_1_f32(
             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
+        // we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this
+        for (int i = 0; i < 32; i++) {
+            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+        }
 
         for (int row = 0; row < N_DST; row++) {
             // prefetch next x block
@@ -496,8 +517,9 @@ kernel void kernel_mul_mat_q4_1_f32(
             const float d = qb_curr.d;
            const float m = qb_curr.m;
             float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+            for (int i = 0; i < 16; i+=2) {
+                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
             }
             sumf[row] += d * acc + m * sumy;
             qb_curr = qb_next;
@@ -518,6 +540,9 @@ kernel void kernel_mul_mat_q4_1_f32(
             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
+        for (int i = 0; i < 32; i++) {
+            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+        }
 
         for (int row = 0; row < N_DST; row++) {
             // prefetch next x block
@@ -527,8 +552,9 @@ kernel void kernel_mul_mat_q4_1_f32(
             const float d = qb_curr.d;
             const float m = qb_curr.m;
             float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+            for (int i = 0; i < 16; i+=2) {
+                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
            }
             if (tiisg < nb % N_SIMDWIDTH) {
                 sumf[row] += d * acc + m * sumy;
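The idea the patch relies on can be checked outside of Metal: masking a nibble in place instead of shifting it out leaves the value scaled by 16, 256, or 4096 depending on its position in the uint16, so pre-dividing the matching y value by the same power of 16 keeps the dot product unchanged. The kernels do this pre-scaling once per column via pow(1.f/16.f, 2 * (i % 2) + i / 16), so the inner accumulation loop needs only AND and no shifts. Below is a minimal, standalone host-side C sketch of that identity; it is not part of the patch, and the variable names are illustrative only.

// check_shift_free.c -- verify the shift-free nibble trick on the host (illustrative, not part of the patch)
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint16_t q = 0xA3C7;                      // four packed 4-bit quants: 0x7, 0xC, 0x3, 0xA
    const float y[4] = {0.5f, -1.25f, 2.0f, 3.75f}; // matching activations

    // reference: extract each nibble with a right shift, then multiply
    float ref = 0.f;
    ref += y[0] * ((q >>  0) & 0xF);
    ref += y[1] * ((q >>  4) & 0xF);
    ref += y[2] * ((q >>  8) & 0xF);
    ref += y[3] * ((q >> 12) & 0xF);

    // shift-free: mask the nibble in place (value stays scaled by 16^k)
    // and pre-divide y by the same power of 16 to compensate
    float acc = 0.f;
    acc += (y[0] / 1.f)    * (q & 0x000F);
    acc += (y[1] / 16.f)   * (q & 0x00F0);
    acc += (y[2] / 256.f)  * (q & 0x0F00);
    acc += (y[3] / 4096.f) * (q & 0xF000);

    printf("with shifts = %f, shift-free = %f\n", ref, acc); // both print 32.000000
    return 0;
}

For q4_0 the additional -8 offset is still handled separately through sumy *= (-8.f), exactly as before; only the per-nibble extraction changes.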