Slowly progressing on Q3_K on metal
We are now 13% faster than master
This commit is contained in:
parent
123a870b36
commit
9eb1d4d347
1 changed files with 36 additions and 17 deletions
|
@ -1153,15 +1153,15 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
const uint16_t m1 = 1 << (4*ip + il);
|
const uint16_t m1 = 1 << (4*ip + il);
|
||||||
const uint16_t m2 = m1 << 8;
|
const uint16_t m2 = m1 << 8;
|
||||||
|
//const uint16_t m1 = ip == 0 ? il == 0 ? 0x0001 : 0x0004 : il == 0 ? 0x0010 : 0x0040;
|
||||||
|
//const uint16_t m2 = ip == 0 ? il == 0 ? 0x0100 : 0x0400 : il == 0 ? 0x1000 : 0x4000;
|
||||||
const uint16_t m3 = m1 << 1;
|
const uint16_t m3 = m1 << 1;
|
||||||
const uint16_t m4 = m2 << 1;
|
const uint16_t m4 = m2 << 1;
|
||||||
|
|
||||||
|
const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
|
||||||
|
|
||||||
const int shift = 2*il;
|
const int shift = 2*il;
|
||||||
const int32_t qm1 = 0x0003 << shift;
|
const float v1 = il == 0 ? 4.f : 64.f;
|
||||||
const int32_t qm2 = 0x0300 << shift;
|
|
||||||
const int32_t qm3 = qm1 << 2;
|
|
||||||
const int32_t qm4 = qm2 << 2;
|
|
||||||
const float v1 = 4 << shift;
|
|
||||||
const float v2 = 4.f * v1;
|
const float v2 = 4.f * v1;
|
||||||
|
|
||||||
const uint16_t s_shift1 = 4*ip;
|
const uint16_t s_shift1 = 4*ip;
|
||||||
|
@ -1174,8 +1174,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
uint16_t scales16[2];
|
uint32_t scales32, aux32;
|
||||||
thread const int8_t * scales = (thread const int8_t *)scales16;
|
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
||||||
|
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
||||||
|
|
||||||
|
//uint16_t scales16[2];
|
||||||
|
//thread const int8_t * scales = (thread const int8_t *)scales16;
|
||||||
|
|
||||||
float sumf1[2] = {0.f};
|
float sumf1[2] = {0.f};
|
||||||
float sumf2[2] = {0.f};
|
float sumf2[2] = {0.f};
|
||||||
|
@ -1196,17 +1200,28 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
const float d_all = (float)dh[0];
|
const float d_all = (float)dh[0];
|
||||||
scales16[0] = ((a[il+0] >> s_shift1) & kmask2) | (((a[4] >> s_shift2) << 4) & kmask1);
|
|
||||||
scales16[1] = ((a[il+1] >> s_shift1) & kmask2) | (((a[5] >> s_shift2) << 4) & kmask1);
|
scales16[0] = a[4];
|
||||||
|
scales16[1] = a[5];
|
||||||
|
aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
|
||||||
|
scales16[0] = a[il+0];
|
||||||
|
scales16[1] = a[il+1];
|
||||||
|
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
|
||||||
|
//scales16[0] = ((a[il+0] >> s_shift1) & kmask2) | (((a[4] >> s_shift2) << 4) & kmask1);
|
||||||
|
//scales16[1] = ((a[il+1] >> s_shift1) & kmask2) | (((a[5] >> s_shift2) << 4) & kmask1);
|
||||||
|
|
||||||
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const int32_t qs = q[l/2];
|
const int32_t qs = q[l/2];
|
||||||
s1 += yl[l+0] * (qs & qm1);
|
s1 += yl[l+0] * (qs & qm[il/2][0]);
|
||||||
s2 += yl[l+1] * (qs & qm2);
|
s2 += yl[l+1] * (qs & qm[il/2][1]);
|
||||||
|
//s1 += yl[l+0] * (qs & qm1);
|
||||||
|
//s2 += yl[l+1] * (qs & qm2);
|
||||||
s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
|
s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
|
||||||
s4 += yl[l+16] * (qs & qm3);
|
s4 += yl[l+16] * (qs & qm[il/2][2]);
|
||||||
s5 += yl[l+17] * (qs & qm4);
|
s5 += yl[l+17] * (qs & qm[il/2][3]);
|
||||||
|
//s4 += yl[l+16] * (qs & qm3);
|
||||||
|
//s5 += yl[l+17] * (qs & qm4);
|
||||||
s6 += ((h[l/2] & m3) ? 0.f : yl[l+16]) + ((h[l/2] & m4) ? 0.f : yl[l+17]);
|
s6 += ((h[l/2] & m3) ? 0.f : yl[l+16]) + ((h[l/2] & m4) ? 0.f : yl[l+17]);
|
||||||
}
|
}
|
||||||
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
|
@ -1217,11 +1232,15 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const int32_t qs = q[l/2+8];
|
const int32_t qs = q[l/2+8];
|
||||||
s1 += yl[l+8] * (qs & qm1);
|
//s1 += yl[l+8] * (qs & qm1);
|
||||||
s2 += yl[l+9] * (qs & qm2);
|
//s2 += yl[l+9] * (qs & qm2);
|
||||||
|
s1 += yl[l+8] * (qs & qm[il/2][0]);
|
||||||
|
s2 += yl[l+9] * (qs & qm[il/2][1]);
|
||||||
s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
|
s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
|
||||||
s4 += yl[l+24] * (qs & qm3);
|
s4 += yl[l+24] * (qs & qm[il/2][2]);
|
||||||
s5 += yl[l+25] * (qs & qm4);
|
s5 += yl[l+25] * (qs & qm[il/2][3]);
|
||||||
|
//s4 += yl[l+24] * (qs & qm3);
|
||||||
|
//s5 += yl[l+25] * (qs & qm4);
|
||||||
s6 += ((h[l/2+8] & m3) ? 0.f : yl[l+24]) + ((h[l/2+8] & m4) ? 0.f : yl[l+25]);
|
s6 += ((h[l/2+8] & m3) ? 0.f : yl[l+24]) + ((h[l/2+8] & m4) ? 0.f : yl[l+25]);
|
||||||
}
|
}
|
||||||
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue