3-5% faster Q4_0 on Metal

This commit is contained in:
Iwan Kawrakow 2023-07-13 10:32:19 +02:00
parent 1cbf561466
commit 585ac35b42

View file

@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// each thread in a SIMD group deals with 1 block. // each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) { for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) { for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
} }
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < N_DST; row++) {
// prefetch next x block // prefetch next x block
@ -405,19 +408,30 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate // calculate
float d = qb_curr.d; float d = qb_curr.d;
float2 acc = {0.0f, 0.0f}; float acc = sumy;
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
} }
sumf[row] += d * (acc[0] - 8.f*acc[1]); sumf[row] += d * acc;
qb_curr = qb_next; qb_curr = qb_next;
} }
} }
if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
} else {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) { for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
} }
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < N_DST; row++) {
// prefetch next x block // prefetch next x block
@ -425,13 +439,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate // calculate
float d = qb_curr.d; float d = qb_curr.d;
float2 acc = {0.0f, 0.0f}; float acc = sumy;
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
} }
if (tiisg < nb % N_SIMDWIDTH) { if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * (acc[0] - 8.f*acc[1]); sumf[row] += d * acc;
} }
qb_curr = qb_next; qb_curr = qb_next;
@ -440,6 +453,7 @@ kernel void kernel_mul_mat_q4_0_f32(
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
} }
}
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mat_q4_1_f32(