metal: use uint16_t instead of uint8_t.
Apple GPUs do not handle uint8_t well. For every operation on a uint8_t, the GPU needs to copy the value into an empty 16-bit register before it can issue other instructions. For the matrix-vector multiplication kernel alone, we observed a memory read speed of 340-350 GB/s on M1 Max after this commit, which is very close to the reported hardware limit.
This commit is contained in:
parent
a6803cab94
commit
bbce392890
1 changed file with 46 additions and 20 deletions
|
@ -8,14 +8,14 @@ using namespace metal;
|
||||||
#define QR4_0 2
|
#define QR4_0 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
uint16_t qs[QK4_0 / 4]; // nibbles / quants
|
||||||
} block_q4_0;
|
} block_q4_0;
|
||||||
|
|
||||||
#define QK4_1 32
|
#define QK4_1 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
half m; // min
|
half m; // min
|
||||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
uint16_t qs[QK4_1 / 4]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
|
|
||||||
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
||||||
|
@ -28,12 +28,16 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
const half d = x[i].d;
|
const half d = x[i].d;
|
||||||
|
|
||||||
for (int j = 0; j < qk/2; ++j) {
|
for (int j = 0; j < qk/4; ++j) {
|
||||||
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
const int x0 = (x[i].qs[j] & 0x000F) - 8;
|
||||||
const int x1 = (x[i].qs[j] >> 4) - 8;
|
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4) - 8;
|
||||||
|
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8) - 8;
|
||||||
|
const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
|
||||||
|
|
||||||
y[i*qk + j + 0 ] = x0*d;
|
y[i*qk + 2 * j + 0 ] = x0*d;
|
||||||
y[i*qk + j + qk/2] = x1*d;
|
y[i*qk + 2 * j + qk/2 ] = x1*d;
|
||||||
|
y[i*qk + 2 * j + 1 ] = x2*d;
|
||||||
|
y[i*qk + 2 * j + 1 + qk/2] = x3*d;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,12 +53,16 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
|
||||||
const half d = x[i].d;
|
const half d = x[i].d;
|
||||||
const half m = x[i].m;
|
const half m = x[i].m;
|
||||||
|
|
||||||
for (int j = 0; j < qk/2; ++j) {
|
for (int j = 0; j < qk/4; ++j) {
|
||||||
const int x0 = (x[i].qs[j] & 0x0F);
|
const int x0 = (x[i].qs[j] & 0x000F);
|
||||||
const int x1 = (x[i].qs[j] >> 4);
|
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4);
|
||||||
|
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8);
|
||||||
|
const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
|
||||||
|
|
||||||
y[i*qk + j + 0 ] = x0*d + m;
|
y[i*qk + 2 * j + 0 ] = x0*d + m;
|
||||||
y[i*qk + j + qk/2] = x1*d + m;
|
y[i*qk + 2 * j + qk/2 ] = x1*d + m;
|
||||||
|
y[i*qk + 2 * j + 1 ] = x2*d + m;
|
||||||
|
y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -401,6 +409,10 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
sumy *= (-8.f);
|
sumy *= (-8.f);
|
||||||
|
// we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
|
||||||
|
for (int i = 0; i < 32; i++) {
|
||||||
|
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||||
|
}
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
|
@ -409,8 +421,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
// calculate
|
// calculate
|
||||||
float d = qb_curr.d;
|
float d = qb_curr.d;
|
||||||
float acc = sumy;
|
float acc = sumy;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i+=2) {
|
||||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||||
|
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||||
}
|
}
|
||||||
sumf[row] += d * acc;
|
sumf[row] += d * acc;
|
||||||
qb_curr = qb_next;
|
qb_curr = qb_next;
|
||||||
|
@ -432,6 +445,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
sumy *= (-8.f);
|
sumy *= (-8.f);
|
||||||
|
for (int i = 0; i < 32; i++) {
|
||||||
|
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||||
|
}
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
|
@ -440,8 +456,9 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
// calculate
|
// calculate
|
||||||
float d = qb_curr.d;
|
float d = qb_curr.d;
|
||||||
float acc = sumy;
|
float acc = sumy;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i+=2) {
|
||||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||||
|
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||||
}
|
}
|
||||||
if (tiisg < nb % N_SIMDWIDTH) {
|
if (tiisg < nb % N_SIMDWIDTH) {
|
||||||
sumf[row] += d * acc;
|
sumf[row] += d * acc;
|
||||||
|
@ -487,6 +504,10 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
|
// we don't right-shift the packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
|
||||||
|
for (int i = 0; i < 32; i++) {
|
||||||
|
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||||
|
}
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
|
@ -496,8 +517,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
const float d = qb_curr.d;
|
const float d = qb_curr.d;
|
||||||
const float m = qb_curr.m;
|
const float m = qb_curr.m;
|
||||||
float acc = 0.f;
|
float acc = 0.f;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i+=2) {
|
||||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||||
|
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||||
}
|
}
|
||||||
sumf[row] += d * acc + m * sumy;
|
sumf[row] += d * acc + m * sumy;
|
||||||
qb_curr = qb_next;
|
qb_curr = qb_next;
|
||||||
|
@ -518,6 +540,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 32; i++) {
|
||||||
|
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
||||||
|
}
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
|
@ -527,8 +552,9 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
const float d = qb_curr.d;
|
const float d = qb_curr.d;
|
||||||
const float m = qb_curr.m;
|
const float m = qb_curr.m;
|
||||||
float acc = 0.f;
|
float acc = 0.f;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i+=2) {
|
||||||
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
||||||
|
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
||||||
}
|
}
|
||||||
if (tiisg < nb % N_SIMDWIDTH) {
|
if (tiisg < nb % N_SIMDWIDTH) {
|
||||||
sumf[row] += d * acc + m * sumy;
|
sumf[row] += d * acc + m * sumy;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue