metal: very slightly faster token generation (TG) for Q4_K

This commit is contained in:
Iwan Kawrakow 2023-09-11 08:14:45 +02:00
parent d90b5981d0
commit b42dfdcd89

View file

@ -733,7 +733,10 @@ kernel void kernel_mul_mat_f16_f32_old(
}
for (int row = 0; row < N_F16_F32; ++row) {
float all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && rb + row < ne11) {
if (rb + row >= ne11) {
break;
}
if (tiisg == 0) {
dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum;
}
}
@ -1545,8 +1548,7 @@ kernel void kernel_mul_mat_q4_K_f32(
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
float yh[16];
float4 yl[8];
float sumf[N_DST]={0.f}, all_sum;
const int step = sizeof(block_q4_K) * nb / 2;
@ -1556,14 +1558,14 @@ kernel void kernel_mul_mat_q4_K_f32(
uint16_t sc16[4];
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f};
for (int ib = ix; ib < nb; ib += 4) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
float4 sumy = {0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
yl[i] = {y4[i], y4[i+32], y4[i+128], y4[i+160]};
sumy += yl[i];
}
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
@ -1582,22 +1584,23 @@ kernel void kernel_mul_mat_q4_K_f32(
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
acc1[0] += yl[i+0][0] * (q1[i/2] & 0x000F);
acc1[1] += yl[i+0][1] * (q1[i/2] & 0x00F0);
acc1[2] += yl[i+0][2] * (q2[i/2] & 0x000F);
acc1[3] += yl[i+0][3] * (q2[i/2] & 0x00F0);
acc2[0] += yl[i+1][0] * (q1[i/2] & 0x0F00);
acc2[1] += yl[i+1][1] * (q1[i/2] & 0xF000);
acc2[2] += yl[i+1][2] * (q2[i/2] & 0x0F00);
acc2[3] += yl[i+1][3] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
acc1 += acc2 / 256.f;
acc1 *= norm;
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;