metal : still optimizing Q4_K

This commit pushes it down to 25.3 ms / token.

The crazy idea of using 6 bits for the scales is really costly on
Metal: if I remove the bit fiddling necessary to make the block
scales, time goes almost to the Q4_0 23 ms/token.

Before pushing the k-quants upstream I had a Q4_K variant that
had used 8-bit scales. It wasn't more accurate, used 0.125 bits more per weight,
was running slightly slower on the CPU (due to the larger model size
and being memory bound there), and the difference was entirely
negligible under CUDA. So, I decided to publish the version with 6-bit
scales. Perhaps I should re-consider and change to 8-bit scales?
This commit is contained in:
Iwan Kawrakow 2023-06-10 15:57:33 +03:00
parent 5e2f67fe00
commit fff0e4f9d8

View file

@ -986,42 +986,82 @@ kernel void kernel_mul_mat_q4_k_f32(
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
//const int ir = tid%4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
sum[ith] = 0.0f;
//uint16_t aux_scales[4];
//thread uint8_t * sc = (thread uint8_t *)aux_scales;
//uint32_t aux32[4];
//thread const uint8_t * sc = (thread const uint8_t *)aux32;
uchar2 sc1, sc2, sc3, sc4;
float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q1 = (x + i)->qs + 32*im + l0;
device const float * y1 = yy + i*QK_K + 64*im + l0;
device const uint8_t * q1 = (x + i)->qs + q_offset;
device const uint8_t * q2 = q1 + 64;
device const float * y1 = yy + i*QK_K + y_offset;
device const float * y2 = y1 + 128;
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
const float dall = (float)((x + i)->d);
const float dmin = (float)((x + i)->dmin);
const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
//device const uint32_t * a = (device const uint32_t *)(x + i)->scales;
//aux32[0] = a[0] & 0x3f3f3f3f; // scales for 0, 32, 64, 96
//aux32[1] = a[1] & 0x3f3f3f3f; // mins for 0, 32, 64, 96
//aux32[2] = ((a[2] >> 0) & 0x0f0f0f0f) | ((a[0] & 0xc0c0c0c0) >> 2); // scales for 128, 160, 192, 224
//aux32[3] = ((a[2] >> 4) & 0x0f0f0f0f) | ((a[1] & 0xc0c0c0c0) >> 2); // mins for 128, 160, 192, 224
float2 s = {0.f, 0.f};
//aux_scales[0] = (uint16_t)(a[im+0] & kmask1);
//aux_scales[1] = (uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2));
//aux_scales[2] = (uint16_t)(a[im+2] & kmask1);
//aux_scales[3] = (uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2));
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
//float2 s = {0.f, 0.f};
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
+ y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
////s[0] += y1[l] * sc[0] * (q1[l] & 0xF) + y1[l+32] * sc[1] * (q1[l] >> 4)
//// + y2[l] * sc[2] * (q2[l] & 0xF) + y2[l+32] * sc[3] * (q2[l] >> 4);
////s[1] += y1[l] * sc[4] + y1[l+32] * sc[5] + y2[l] * sc[6] + y2[l+32] * sc[7];
////s[0] += y1[l] * sc[2*im+0] * (q1[l] & 0xF) + y1[l+32] * sc[2*im+1] * (q1[l] >> 4)
//// + y2[l] * sc[2*im+8] * (q2[l] & 0xF) + y2[l+32] * sc[2*im+9] * (q2[l] >> 4);
////s[1] += y1[l] * sc[2*im+4] + y1[l+32] * sc[2*im+5] + y2[l] * sc[2*im+12] + y2[l+32] * sc[2*im+13];
//s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
// + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
//s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
}
sumf += dall * s[0] - dmin * s[1];
//sumf += dall * s[0] - dmin * s[1];
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
}
sum[ith] = sumf;
//