metal: use uint16_t instead of uint8_t.
Apple GPU doesn't like uint8_t. For every operation on uint8_t the gpu need to copy the uint8_t to an empty 16 bit register, then it can issue other instructions. For the matrix-vector multiplication kernel only, we observed a 340~350 GB/s memory read speed on M1 Max after this commit, which is very close to the reported hardware limit.
This commit is contained in:
parent
a6803cab94
commit
bbce392890
1 changed files with 46 additions and 20 deletions
|
@ -8,14 +8,14 @@ using namespace metal;
|
|||
#define QR4_0 2
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
uint16_t qs[QK4_0 / 4]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
|
||||
#define QK4_1 32
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
uint16_t qs[QK4_1 / 4]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
|
||||
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
||||
|
@ -28,12 +28,16 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
|
|||
for (int i = 0; i < nb; i++) {
|
||||
const half d = x[i].d;
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
||||
const int x1 = (x[i].qs[j] >> 4) - 8;
|
||||
for (int j = 0; j < qk/4; ++j) {
|
||||
const int x0 = (x[i].qs[j] & 0x000F) - 8;
|
||||
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4) - 8;
|
||||
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8) - 8;
|
||||
const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d;
|
||||
y[i*qk + j + qk/2] = x1*d;
|
||||
y[i*qk + 2 * j + 0 ] = x0*d;
|
||||
y[i*qk + 2 * j + qk/2 ] = x1*d;
|
||||
y[i*qk + 2 * j + 1 ] = x2*d;
|
||||
y[i*qk + 2 * j + 1 + qk/2] = x3*d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -49,12 +53,16 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
|
|||
const half d = x[i].d;
|
||||
const half m = x[i].m;
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const int x0 = (x[i].qs[j] & 0x0F);
|
||||
const int x1 = (x[i].qs[j] >> 4);
|
||||
for (int j = 0; j < qk/4; ++j) {
|
||||
const int x0 = (x[i].qs[j] & 0x000F);
|
||||
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4);
|
||||
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8);
|
||||
const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d + m;
|
||||
y[i*qk + j + qk/2] = x1*d + m;
|
||||
y[i*qk + 2 * j + 0 ] = x0*d + m;
|
||||
y[i*qk + 2 * j + qk/2 ] = x1*d + m;
|
||||
y[i*qk + 2 * j + 1 ] = x2*d + m;
|
||||
y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -401,6 +409,10 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||
}
|
||||
sumy *= (-8.f);
|
||||
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
|
||||
for (int i = 0; i < 32; i++) {
|
||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
// prefetch next x block
|
||||
|
@ -409,8 +421,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||
// calculate
|
||||
float d = qb_curr.d;
|
||||
float acc = sumy;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||
for (int i = 0; i < 16; i+=2) {
|
||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||
}
|
||||
sumf[row] += d * acc;
|
||||
qb_curr = qb_next;
|
||||
|
@ -432,6 +445,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||
}
|
||||
sumy *= (-8.f);
|
||||
for (int i = 0; i < 32; i++) {
|
||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
// prefetch next x block
|
||||
|
@ -440,8 +456,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||
// calculate
|
||||
float d = qb_curr.d;
|
||||
float acc = sumy;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||
for (int i = 0; i < 16; i+=2) {
|
||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||
}
|
||||
if (tiisg < nb % N_SIMDWIDTH) {
|
||||
sumf[row] += d * acc;
|
||||
|
@ -487,6 +504,10 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||
}
|
||||
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
|
||||
for (int i = 0; i < 32; i++) {
|
||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
// prefetch next x block
|
||||
|
@ -496,8 +517,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||
const float d = qb_curr.d;
|
||||
const float m = qb_curr.m;
|
||||
float acc = 0.f;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||
for (int i = 0; i < 16; i+=2) {
|
||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||
}
|
||||
sumf[row] += d * acc + m * sumy;
|
||||
qb_curr = qb_next;
|
||||
|
@ -518,6 +540,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||
}
|
||||
for (int i = 0; i < 32; i++) {
|
||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
// prefetch next x block
|
||||
|
@ -527,8 +552,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||
const float d = qb_curr.d;
|
||||
const float m = qb_curr.m;
|
||||
float acc = 0.f;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||
for (int i = 0; i < 16; i+=2) {
|
||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||
}
|
||||
if (tiisg < nb % N_SIMDWIDTH) {
|
||||
sumf[row] += d * acc + m * sumy;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue