metal: use uint16_t instead of uint8_t.

Apple GPUs don't handle uint8_t well: for every operation on a uint8_t, the
GPU first has to copy the value into an empty 16-bit register before it can
issue further instructions.

For the matrix-vector multiplication kernel alone, we observed a memory read
speed of 340~350 GB/s on M1 Max after this commit, which is very close to the
reported hardware limit.
Author: lshzh-ww
Date:   2023-07-15 02:15:02 -04:00
commit  bbce392890
parent  a6803cab94


@@ -8,14 +8,14 @@ using namespace metal;
 #define QR4_0 2
 typedef struct {
     half     d;             // delta
-    uint8_t  qs[QK4_0 / 2]; // nibbles / quants
+    uint16_t qs[QK4_0 / 4]; // nibbles / quants
 } block_q4_0;

 #define QK4_1 32
 typedef struct {
     half d;          // delta
     half m;          // min
-    uint8_t  qs[QK4_1 / 2]; // nibbles / quants
+    uint16_t qs[QK4_1 / 4]; // nibbles / quants
 } block_q4_1;

 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -28,12 +28,16 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     for (int i = 0; i < nb; i++) {
         const half d = x[i].d;

-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F) - 8;
-            const int x1 = (x[i].qs[j] >>   4) - 8;
+        for (int j = 0; j < qk/4; ++j) {
+            const int x0 = ( x[i].qs[j] & 0x000F)        - 8;
+            const int x1 = ((x[i].qs[j] & 0x00F0) >>  4) - 8;
+            const int x2 = ((x[i].qs[j] & 0x0F00) >>  8) - 8;
+            const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;

-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
+            y[i*qk + 2 * j + 0       ] = x0*d;
+            y[i*qk + 2 * j + qk/2    ] = x1*d;
+            y[i*qk + 2 * j + 1       ] = x2*d;
+            y[i*qk + 2 * j + 1 + qk/2] = x3*d;
         }
     }
 }
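A quick host-side check of the index remapping above (a sketch; it assumes little-endian packing of qs, and qk, d, and the fill pattern are invented for the example): the new qk/4 loop writes exactly the same values the old qk/2 loop did. dequantize_row_q4_1 below is the same remapping, computing x*d + m instead of (x - 8)*d.

    // Sketch: old byte-wise loop vs. new uint16-wise loop produce
    // identical output. Data and qk are illustrative, not from the commit.
    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    int main() {
        const int qk = 32;
        uint8_t qs8[qk / 2];
        for (int j = 0; j < qk / 2; ++j) qs8[j] = (uint8_t)(17 * j + 3);

        uint16_t qs16[qk / 4];
        std::memcpy(qs16, qs8, sizeof(qs8)); // same bytes, viewed as uint16_t

        const float d = 0.5f;
        float y_old[qk], y_new[qk];

        for (int j = 0; j < qk / 2; ++j) {   // old layout: one byte per step
            y_old[j       ] = ((qs8[j] & 0x0F) - 8) * d;
            y_old[j + qk/2] = ((qs8[j] >>   4) - 8) * d;
        }
        for (int j = 0; j < qk / 4; ++j) {   // new layout: one uint16 per step
            y_new[2*j           ] = (( qs16[j] & 0x000F)        - 8) * d;
            y_new[2*j + qk/2    ] = (((qs16[j] & 0x00F0) >>  4) - 8) * d;
            y_new[2*j + 1       ] = (((qs16[j] & 0x0F00) >>  8) - 8) * d;
            y_new[2*j + 1 + qk/2] = (((qs16[j] & 0xF000) >> 12) - 8) * d;
        }
        printf("%s\n", std::memcmp(y_old, y_new, sizeof(y_old)) == 0 ? "match" : "mismatch");
    }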
@@ -49,12 +53,16 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
         const half d = x[i].d;
         const half m = x[i].m;

-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F);
-            const int x1 = (x[i].qs[j] >>   4);
+        for (int j = 0; j < qk/4; ++j) {
+            const int x0 = ( x[i].qs[j] & 0x000F);
+            const int x1 = ((x[i].qs[j] & 0x00F0) >>  4);
+            const int x2 = ((x[i].qs[j] & 0x0F00) >>  8);
+            const int x3 = ((x[i].qs[j] & 0xF000) >> 12);

-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
+            y[i*qk + 2 * j + 0       ] = x0*d + m;
+            y[i*qk + 2 * j + qk/2    ] = x1*d + m;
+            y[i*qk + 2 * j + 1       ] = x2*d + m;
+            y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
         }
     }
 }
@@ -401,6 +409,10 @@ kernel void kernel_mul_mat_q4_0_f32(
         sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
     }
     sumy *= (-8.f);
+    // we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate.
+    for (int i = 0; i < 32; i++) {
+        yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+    }

     for (int row = 0; row < N_DST; row++) {
         // prefetch next x block
@@ -409,8 +421,9 @@ kernel void kernel_mul_mat_q4_0_f32(
         // calculate
         float d = qb_curr.d;
         float acc = sumy;
-        for (int i = 0; i < 16; i++) {
-            acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+        for (int i = 0; i < 16; i += 2) {
+            acc += yl[i    ] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+            acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
         }
         sumf[row] += d * acc;
         qb_curr = qb_next;
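The yl rescaling a few lines earlier is what makes this mask-only loop correct: the 0x00F0, 0x0F00, and 0xF000 nibbles are never shifted down, so they arrive scaled by 16, 256, and 4096, and pre-dividing the matching yl entry by 16^(2*(i%2) + i/16) cancels that exactly, while sumy *= -8.f folds the constant -8 offset out of the inner loop. A host-side sketch of the identity (data invented for the example; the same pattern reappears in kernel_mul_mat_q4_1_f32 below, with m * sumy instead of the -8 fold):

    // Sketch: sum yl[i]*(q_i - 8) == -8*sumy + sum yl_scaled[i]*masked_i.
    #include <cstdint>
    #include <cmath>
    #include <cstdio>

    int main() {
        uint16_t qs[8];
        float    yl[32];
        for (int i = 0; i < 8;  ++i) qs[i] = (uint16_t)(0x9B3Du * (i + 1)); // arbitrary bits
        for (int i = 0; i < 32; ++i) yl[i] = 0.25f * i - 3.0f;              // arbitrary activations

        // Reference dot product: shift every nibble into place, subtract 8.
        float ref = 0.f;
        for (int i = 0; i < 16; i += 2) {
            ref += yl[i     ] * (((qs[i/2]      ) & 0xF) - 8);
            ref += yl[i + 16] * (((qs[i/2] >>  4) & 0xF) - 8);
            ref += yl[i +  1] * (((qs[i/2] >>  8) & 0xF) - 8);
            ref += yl[i + 17] * (((qs[i/2] >> 12) & 0xF) - 8);
        }

        // Kernel's version: fold the -8 into sumy, pre-scale yl, mask only.
        float sumy = 0.f;
        for (int i = 0; i < 32; i++) sumy += yl[i];
        for (int i = 0; i < 32; i++) yl[i] *= std::pow(1.f/16.f, (float)(2 * (i % 2) + i / 16));

        float acc = sumy * (-8.f);
        for (int i = 0; i < 16; i += 2) {
            acc += yl[i    ] * (qs[i/2] & 0x000F) + yl[i + 16] * (qs[i/2] & 0x00F0);
            acc += yl[i + 1] * (qs[i/2] & 0x0F00) + yl[i + 17] * (qs[i/2] & 0xF000);
        }
        printf("ref = %f, acc = %f\n", ref, acc); // the two should agree
    }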
@@ -432,6 +445,9 @@ kernel void kernel_mul_mat_q4_0_f32(
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
         sumy *= (-8.f);
+        for (int i = 0; i < 32; i++) {
+            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+        }

         for (int row = 0; row < N_DST; row++) {
             // prefetch next x block
@@ -440,8 +456,9 @@ kernel void kernel_mul_mat_q4_0_f32(
             // calculate
             float d = qb_curr.d;
             float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+            for (int i = 0; i < 16; i += 2) {
+                acc += yl[i    ] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
             }
             if (tiisg < nb % N_SIMDWIDTH) {
                 sumf[row] += d * acc;
@@ -487,6 +504,10 @@ kernel void kernel_mul_mat_q4_1_f32(
         y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
         sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
     }
+    // we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate.
+    for (int i = 0; i < 32; i++) {
+        yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+    }

     for (int row = 0; row < N_DST; row++) {
         // prefetch next x block
@@ -496,8 +517,9 @@ kernel void kernel_mul_mat_q4_1_f32(
         const float d = qb_curr.d;
         const float m = qb_curr.m;
         float acc = 0.f;
-        for (int i = 0; i < 16; i++) {
-            acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+        for (int i = 0; i < 16; i += 2) {
+            acc += yl[i    ] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+            acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
         }
         sumf[row] += d * acc + m * sumy;
         qb_curr = qb_next;
@@ -518,6 +540,9 @@ kernel void kernel_mul_mat_q4_1_f32(
         y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
         sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
     }
+    for (int i = 0; i < 32; i++) {
+        yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
+    }

     for (int row = 0; row < N_DST; row++) {
         // prefetch next x block
@@ -527,8 +552,9 @@ kernel void kernel_mul_mat_q4_1_f32(
         const float d = qb_curr.d;
         const float m = qb_curr.m;
         float acc = 0.f;
-        for (int i = 0; i < 16; i++) {
-            acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+        for (int i = 0; i < 16; i += 2) {
+            acc += yl[i    ] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+            acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
         }
         if (tiisg < nb % N_SIMDWIDTH) {
             sumf[row] += d * acc + m * sumy;