Another Q3_K speedup on metal
Combined with previous commit, we are now +9.6% for TG. PP is not affected as this happens via the matrix multiplication templates.
This commit is contained in:
parent
ec13de521c
commit
123a870b36
1 changed files with 40 additions and 22 deletions
|
@ -1138,30 +1138,34 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
||||||
|
|
||||||
float yl[16];
|
float yl[32];
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x3030;
|
const uint16_t kmask1 = 0x3030;
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
const int tid = tiisg/2;
|
const int tid = tiisg/4;
|
||||||
const int ix = tiisg%2;
|
const int ix = tiisg%4;
|
||||||
const int ip = tid/8; // 0 or 1
|
const int ip = tid/4; // 0 or 1
|
||||||
const int il = tid/2 - 4*ip; // 0...3
|
const int il = 2*((tid%4)/2); // 0 or 2
|
||||||
const int ir = tid%2;
|
const int ir = tid%2;
|
||||||
const int n = 8;
|
const int n = 8;
|
||||||
const int l0 = n*ir;
|
const int l0 = n*ir;
|
||||||
|
|
||||||
const uint16_t m1 = 1 << (4*ip + il);
|
const uint16_t m1 = 1 << (4*ip + il);
|
||||||
const uint16_t m2 = m1 << 8;
|
const uint16_t m2 = m1 << 8;
|
||||||
|
const uint16_t m3 = m1 << 1;
|
||||||
|
const uint16_t m4 = m2 << 1;
|
||||||
|
|
||||||
const int shift = 2*il;
|
const int shift = 2*il;
|
||||||
const int32_t qm1 = 0x0003 << shift;
|
const int32_t qm1 = 0x0003 << shift;
|
||||||
const int32_t qm2 = 0x0300 << shift;
|
const int32_t qm2 = 0x0300 << shift;
|
||||||
|
const int32_t qm3 = qm1 << 2;
|
||||||
|
const int32_t qm4 = qm2 << 2;
|
||||||
const float v1 = 4 << shift;
|
const float v1 = 4 << shift;
|
||||||
|
const float v2 = 4.f * v1;
|
||||||
|
|
||||||
const uint16_t s_shift1 = 4*ip;
|
const uint16_t s_shift1 = 4*ip;
|
||||||
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
const uint16_t s_shift2 = s_shift1 + il;
|
||||||
const int ik = 4 + (il%2);
|
|
||||||
|
|
||||||
const int q_offset = 32*ip + l0;
|
const int q_offset = 32*ip + l0;
|
||||||
const int y_offset = 128*ip + 32*il + l0;
|
const int y_offset = 128*ip + 32*il + l0;
|
||||||
|
@ -1170,15 +1174,18 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
uint16_t scales16;
|
uint16_t scales16[2];
|
||||||
thread const int8_t * scales = (thread const int8_t *)&scales16;
|
thread const int8_t * scales = (thread const int8_t *)scales16;
|
||||||
|
|
||||||
float sumf1[2] = {0.f};
|
float sumf1[2] = {0.f};
|
||||||
for (int i = ix; i < nb; i += 2) {
|
float sumf2[2] = {0.f};
|
||||||
|
for (int i = ix; i < nb; i += 4) {
|
||||||
|
|
||||||
for (int l = 0; l < 8; ++l) {
|
for (int l = 0; l < 8; ++l) {
|
||||||
yl[l+ 0] = y1[l+ 0];
|
yl[l+ 0] = y1[l+ 0];
|
||||||
yl[l+ 8] = y1[l+16];
|
yl[l+ 8] = y1[l+16];
|
||||||
|
yl[l+16] = y1[l+32];
|
||||||
|
yl[l+24] = y1[l+48];
|
||||||
}
|
}
|
||||||
|
|
||||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
||||||
|
@ -1189,27 +1196,38 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
const float d_all = (float)dh[0];
|
const float d_all = (float)dh[0];
|
||||||
scales16 = ((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) << 4) & kmask1);
|
scales16[0] = ((a[il+0] >> s_shift1) & kmask2) | (((a[4] >> s_shift2) << 4) & kmask1);
|
||||||
|
scales16[1] = ((a[il+1] >> s_shift1) & kmask2) | (((a[5] >> s_shift2) << 4) & kmask1);
|
||||||
|
|
||||||
float s1 = 0, s2 = 0, s3 = 0;
|
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const int32_t qs = q[l/2];
|
const int32_t qs = q[l/2];
|
||||||
s1 += yl[l+0] * (qs & qm1);
|
s1 += yl[l+0] * (qs & qm1);
|
||||||
s2 += yl[l+1] * (qs & qm2);
|
s2 += yl[l+1] * (qs & qm2);
|
||||||
s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
|
s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
|
||||||
|
s4 += yl[l+16] * (qs & qm3);
|
||||||
|
s5 += yl[l+17] * (qs & qm4);
|
||||||
|
s6 += ((h[l/2] & m3) ? 0.f : yl[l+16]) + ((h[l/2] & m4) ? 0.f : yl[l+17]);
|
||||||
}
|
}
|
||||||
float d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
sumf1[row] += d * (scales[0] - 32);
|
float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||||
|
sumf1[row] += d1 * (scales[0] - 32);
|
||||||
|
sumf2[row] += d2 * (scales[2] - 32);
|
||||||
|
|
||||||
s1 = s2 = s3 = 0;
|
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const int32_t qs = q[l/2+8];
|
const int32_t qs = q[l/2+8];
|
||||||
s1 += yl[l+8] * (qs & qm1);
|
s1 += yl[l+8] * (qs & qm1);
|
||||||
s2 += yl[l+9] * (qs & qm2);
|
s2 += yl[l+9] * (qs & qm2);
|
||||||
s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
|
s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
|
||||||
|
s4 += yl[l+24] * (qs & qm3);
|
||||||
|
s5 += yl[l+25] * (qs & qm4);
|
||||||
|
s6 += ((h[l/2+8] & m3) ? 0.f : yl[l+24]) + ((h[l/2+8] & m4) ? 0.f : yl[l+25]);
|
||||||
}
|
}
|
||||||
d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
sumf1[row] += d * (scales[1] - 32);
|
d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||||
|
sumf1[row] += d1 * (scales[1] - 32);
|
||||||
|
sumf2[row] += d2 * (scales[3] - 32);
|
||||||
|
|
||||||
q += step;
|
q += step;
|
||||||
h += step;
|
h += step;
|
||||||
|
@ -1218,12 +1236,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
y1 += 2 * QK_K;
|
y1 += 4 * QK_K;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
const float sumf = sumf1[row] / (1 << shift);
|
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
||||||
const float tot = simd_sum(sumf);
|
const float tot = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue