3-5% faster Q4_0 on Metal

This commit is contained in:
Iwan Kawrakow 2023-07-13 10:32:19 +02:00
parent 1cbf561466
commit 585ac35b42

View file

@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
@ -405,19 +408,30 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate
float d = qb_curr.d;
float2 acc = {0.0f, 0.0f};
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
sumf[row] += d * (acc[0] - 8.f*acc[1]);
sumf[row] += d * acc;
qb_curr = qb_next;
}
}
if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
} else {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
@ -425,13 +439,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate
float d = qb_curr.d;
float2 acc = {0.0f, 0.0f};
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * (acc[0] - 8.f*acc[1]);
sumf[row] += d * acc;
}
qb_curr = qb_next;
@ -441,6 +454,7 @@ kernel void kernel_mul_mat_q4_0_f32(
}
}
}
}
kernel void kernel_mul_mat_q4_1_f32(
device const void * src0,