Another small improvement for Q3_K on metal

Iwan Kawrakow 2023-09-03 21:50:26 +03:00
parent 9eb1d4d347
commit 2cab21c3db

@@ -1151,15 +1151,21 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int n = 8;
     const int l0 = n*ir;
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
-    //const uint16_t m1 = ip == 0 ? il == 0 ? 0x0001 : 0x0004 : il == 0 ? 0x0010 : 0x0040;
-    //const uint16_t m2 = ip == 0 ? il == 0 ? 0x0100 : 0x0400 : il == 0 ? 0x1000 : 0x4000;
-    const uint16_t m3 = m1 << 1;
-    const uint16_t m4 = m2 << 1;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+    // Possible masks for the low 2 bits
     const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+    const ushort4 hm = mm[2*ip + il/2];
     const int shift = 2*il;
     const float v1 = il == 0 ? 4.f : 64.f;
     const float v2 = 4.f * v1;
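The new mm[] table is just a precomputation: for each of the four reachable (ip, il) states it stores the masks the old code derived with shifts (m1 = 1 << (4*ip + il), m2 = m1 << 8, m3 = m1 << 1, m4 = m2 << 1). A quick host-side check in plain C (my own sketch, not part of the commit) that every table row agrees with the shift-derived values:

/*
 * Verifies that mm[2*ip + il/2] from the Metal kernel equals the four
 * masks (m1, m2, m3, m4) the old code computed with shifts, for all
 * reachable states ip in {0, 1}, il in {0, 2}.
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint16_t mm[4][4] = {
        {0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
        {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
        {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
        {0x0040, 0x4000, 0x0080, 0x8000}, // ip = 1, il = 2
    };
    for (int ip = 0; ip < 2; ++ip) {
        for (int il = 0; il < 4; il += 2) {
            const uint16_t m1 = (uint16_t)(1 << (4*ip + il)); // old code
            const uint16_t m2 = (uint16_t)(m1 << 8);
            const uint16_t m3 = (uint16_t)(m1 << 1);
            const uint16_t m4 = (uint16_t)(m2 << 1);
            const uint16_t *hm = mm[2*ip + il/2];             // new code
            assert(hm[0] == m1 && hm[1] == m2 && hm[2] == m3 && hm[3] == m4);
        }
    }
    printf("mm[] matches the shift-derived masks for all 4 (ip, il) states\n");
    return 0;
}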
@@ -1178,9 +1184,6 @@ kernel void kernel_mul_mat_q3_K_f32(
     thread uint16_t * scales16 = (thread uint16_t *)&scales32;
     thread const int8_t * scales = (thread const int8_t *)&scales32;
-    //uint16_t scales16[2];
-    //thread const int8_t * scales = (thread const int8_t *)scales16;
     float sumf1[2] = {0.f};
     float sumf2[2] = {0.f};
     for (int i = ix; i < nb; i += 4) {
@@ -1207,22 +1210,16 @@ kernel void kernel_mul_mat_q3_K_f32(
             scales16[0] = a[il+0];
             scales16[1] = a[il+1];
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
-            //scales16[0] = ((a[il+0] >> s_shift1) & kmask2) | (((a[4] >> s_shift2) << 4) & kmask1);
-            //scales16[1] = ((a[il+1] >> s_shift1) & kmask2) | (((a[5] >> s_shift2) << 4) & kmask1);
             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);
-                //s1 += yl[l+0] * (qs & qm1);
-                //s2 += yl[l+1] * (qs & qm2);
-                s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
                 s4 += yl[l+16] * (qs & qm[il/2][2]);
                 s5 += yl[l+17] * (qs & qm[il/2][3]);
-                //s4 += yl[l+16] * (qs & qm3);
-                //s5 += yl[l+17] * (qs & qm4);
-                s6 += ((h[l/2] & m3) ? 0.f : yl[l+16]) + ((h[l/2] & m4) ? 0.f : yl[l+17]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
             float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
             float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
@@ -1232,16 +1229,12 @@ kernel void kernel_mul_mat_q3_K_f32(
             s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
                 const int32_t qs = q[l/2+8];
-                //s1 += yl[l+8] * (qs & qm1);
-                //s2 += yl[l+9] * (qs & qm2);
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);
-                s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
                 s4 += yl[l+24] * (qs & qm[il/2][2]);
                 s5 += yl[l+25] * (qs & qm[il/2][3]);
-                //s4 += yl[l+24] * (qs & qm3);
-                //s5 += yl[l+25] * (qs & qm4);
-                s6 += ((h[l/2+8] & m3) ? 0.f : yl[l+24]) + ((h[l/2+8] & m4) ? 0.f : yl[l+25]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
             d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
             d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
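Worth noting why the constants work out: s3 and s6 collect the y values whose high bit in h is clear, and d1/d2 subtract them scaled by v1 and v2. That is the Q3_K decode folded into the dot product, where a clear high bit contributes a -4 offset, scaled up to the bit position the low 2 bits were masked at (v1 is 4 for il == 0 and 64 for il == 2; the final division by 1 << shift undoes the positioning). A scalar model in plain C (my own sketch, assuming the ip = 0, il = 0 lane layout implied by the masks; not part of the commit) checking that the split accumulators s1..s6 reproduce a direct decode of four lanes:

/*
 * Scalar model of one inner-loop step of kernel_mul_mat_q3_K_f32 for the
 * ip = 0, il = 0 state (shift = 0, v1 = 4, v2 = 16, hm = mm[0], qm = qm[0]).
 * Checks that the split accumulation s1..s6 equals a direct decode
 * y * (q - 4 when the matching high bit is clear). Lane layout assumed
 * from the masks: yl[l+0] -> bits 0-1, yl[l+1] -> bits 8-9,
 * yl[l+16] -> bits 2-3, yl[l+17] -> bits 10-11.
 */
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    srand(42);
    for (int iter = 0; iter < 100000; ++iter) {
        const uint16_t qs = (uint16_t)rand();
        const uint16_t h  = (uint16_t)rand();
        float y[4];
        for (int j = 0; j < 4; ++j) y[j] = (float)(rand() % 17 - 8);

        // kernel-style split accumulation
        const float s1 = y[0] * (qs & 0x0003);
        const float s2 = y[1] * (qs & 0x0300);
        const float s3 = ((h & 0x0001) ? 0.f : y[0]) + ((h & 0x0100) ? 0.f : y[1]);
        const float s4 = y[2] * (qs & 0x000c);
        const float s5 = y[3] * (qs & 0x0c00);
        const float s6 = ((h & 0x0002) ? 0.f : y[2]) + ((h & 0x0200) ? 0.f : y[3]);
        const float folded = (s1 + 1.f/256.f*s2 - 4.f*s3) + 0.25f*(s4 + 1.f/256.f*s5 - 16.f*s6);

        // direct Q3_K-style decode of the same four lanes
        const int q0 = (qs >>  0) & 3, q1 = (qs >>  8) & 3;
        const int q2 = (qs >>  2) & 3, q3 = (qs >> 10) & 3;
        const float direct =
            y[0]*(q0 - ((h & 0x0001) ? 0 : 4)) + y[1]*(q1 - ((h & 0x0100) ? 0 : 4)) +
            y[2]*(q2 - ((h & 0x0002) ? 0 : 4)) + y[3]*(q3 - ((h & 0x0200) ? 0 : 4));

        if (fabsf(folded - direct) > 1e-4f) {
            printf("mismatch at iter %d: %g vs %g\n", iter, folded, direct);
            return 1;
        }
    }
    printf("split accumulation matches direct decode\n");
    return 0;
}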
@@ -1261,11 +1254,14 @@ kernel void kernel_mul_mat_q3_K_f32(
     for (int row = 0; row < 2; ++row) {
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        sumf1[row] = simd_sum(sumf);
+    }
+
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
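The last hunk reshapes the epilogue: both row sums are reduced first, reusing sumf1[] to hold the reduced values, and a single tiisg == 0 guard then writes both rows, so the reduction loop stays branch-free and only one divergent branch remains. A scalar analogue in plain C (my own sketch, with a plain loop sum standing in for Metal's simd_sum; not part of the commit) showing the two shapes store identical results:

/*
 * Scalar analogue of the epilogue restructure: reduce first, then store
 * under a single guard, instead of guarding inside the reduction loop.
 * A plain loop sum stands in for Metal's simd_sum().
 */
#include <stdio.h>

#define LANES 32

static float fake_simd_sum(const float v[]) {
    float s = 0.f;
    for (int i = 0; i < LANES; ++i) s += v[i];
    return s;
}

int main(void) {
    float partial[2][LANES];
    for (int row = 0; row < 2; ++row)
        for (int l = 0; l < LANES; ++l)
            partial[row][l] = (float)((row + 1) * l);

    float dst_old[2], dst_new[2], sumf1[2];

    // old shape: one guarded store per loop iteration
    for (int row = 0; row < 2; ++row) {
        const float tot = fake_simd_sum(partial[row]);
        /* if (tiisg == 0) */ dst_old[row] = tot;
    }

    // new shape: reduce both rows, then store both under one guard
    for (int row = 0; row < 2; ++row)
        sumf1[row] = fake_simd_sum(partial[row]);
    /* if (tiisg == 0) */ {
        for (int row = 0; row < 2; ++row) dst_new[row] = sumf1[row];
    }

    printf("old: %.1f %.1f  new: %.1f %.1f\n",
           dst_old[0], dst_old[1], dst_new[0], dst_new[1]);
    return 0;
}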