metal: very slightly faster TG for Q4_K
This commit is contained in:
parent
d90b5981d0
commit
b42dfdcd89
1 changed file with 23 additions and 20 deletions
|
@@ -733,7 +733,10 @@ kernel void kernel_mul_mat_f16_f32_old(
|
|||
}
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
float all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0 && rb + row < ne11) {
|
||||
if (rb + row >= ne11) {
|
||||
break;
|
||||
}
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
|
@@ -1545,8 +1548,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||
const uint offset0 = r2/gqa*(nb*ne0);
|
||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float4 yl[8];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int step = sizeof(block_q4_K) * nb / 2;
|
||||
|
@@ -1556,14 +1558,14 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||
uint16_t sc16[4];
|
||||
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
||||
|
||||
const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f};
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 sumy = {0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
||||
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
||||
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
||||
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
||||
yl[i] = {y4[i], y4[i+32], y4[i+128], y4[i+160]};
|
||||
sumy += yl[i];
|
||||
}
|
||||
|
||||
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
|
||||
|
@@ -1582,22 +1584,23 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
|
||||
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
|
||||
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
|
||||
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
|
||||
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
|
||||
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
|
||||
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
|
||||
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
|
||||
|
||||
acc1[0] += yl[i+0][0] * (q1[i/2] & 0x000F);
|
||||
acc1[1] += yl[i+0][1] * (q1[i/2] & 0x00F0);
|
||||
acc1[2] += yl[i+0][2] * (q2[i/2] & 0x000F);
|
||||
acc1[3] += yl[i+0][3] * (q2[i/2] & 0x00F0);
|
||||
acc2[0] += yl[i+1][0] * (q1[i/2] & 0x0F00);
|
||||
acc2[1] += yl[i+1][1] * (q1[i/2] & 0xF000);
|
||||
acc2[2] += yl[i+1][2] * (q2[i/2] & 0x0F00);
|
||||
acc2[3] += yl[i+1][3] * (q2[i/2] & 0xF000);
|
||||
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dmin = dh[1];
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
||||
acc1 += acc2 / 256.f;
|
||||
acc1 *= norm;
|
||||
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
q1 += step;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue