Much better Q6_K implementation for metal

28.3 ms / token for 7B. Subtracting the ~9 ms spent in
other compute graph operations, we are left with ~19 ms
for the matrix multiplications. The model is ~5.5 GB,
so we are getting 1000 / 19 * 5.5 ≈ 290 GB/s!
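For reference, a small self-contained C sketch of that bandwidth estimate; the 28.3 ms/token, ~9 ms of other graph work, and ~5.5 GB model size are the numbers quoted above, everything else is plain arithmetic:

    #include <stdio.h>

    int main(void) {
        const double ms_per_token   = 28.3;  /* measured time per token, 7B      */
        const double other_graph_ms =  9.0;  /* non-matmul compute graph ops     */
        const double model_gb       =  5.5;  /* size of the Q6_K 7B model        */

        /* time left for the matrix multiplications, ~19 ms */
        const double matmul_ms = ms_per_token - other_graph_ms;

        /* each token effectively streams the whole model once */
        const double gb_per_s = model_gb / (matmul_ms / 1000.0);

        printf("matmul: %.1f ms/token -> ~%.0f GB/s effective bandwidth\n",
               matmul_ms, gb_per_s);
        return 0;
    }

(With the matmul time rounded to 19 ms, as in the message, this comes out at roughly 290 GB/s.)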
Iwan Kawrakow 2023-06-08 18:59:49 +03:00
parent 4a82c8d45c
commit c1be42404a

@@ -782,37 +782,29 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int iqs = step * tpitg.y;    // 0...240 in steps of 16
     const int ip  = iqs / 128;         // 0 or 1
     const int il  = (iqs - 128*ip)/16; // 0...7
-    const int is  = 8*ip;
+    const int n   = 4;
+    const int is  = 8*ip + (n*il)/16;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * ql = x[i].ql + 64*ip + il;
-        device const uint8_t * qh = x[i].qh + 32*ip + il;
+        device const uint8_t * ql = x[i].ql + 64*ip + n*il;
+        device const uint8_t * qh = x[i].qh + 32*ip + n*il;
         device const int8_t  * sc = x[i].scales + is;
-        device const float   * y  = yy + i * QK_K + 128*ip + il;
+        device const float   * y  = yy + i * QK_K + 128*ip + n*il;
 
         const float dall = x[i].d;
 
-        float result = sc[0] * y[  0] * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & kmask1) << 4)) - 32)
-                     + sc[0] * y[  8] * ((int8_t)((ql[ 8] & 0xF) | ((qh[ 8] & kmask1) << 4)) - 32)
-                     + sc[1] * y[ 16] * ((int8_t)((ql[16] & 0xF) | ((qh[16] & kmask1) << 4)) - 32)
-                     + sc[1] * y[ 24] * ((int8_t)((ql[24] & 0xF) | ((qh[24] & kmask1) << 4)) - 32)
-                     + sc[2] * y[ 32] * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & kmask2) << 2)) - 32)
-                     + sc[2] * y[ 40] * ((int8_t)((ql[40] & 0xF) | ((qh[ 8] & kmask2) << 2)) - 32)
-                     + sc[3] * y[ 48] * ((int8_t)((ql[48] & 0xF) | ((qh[16] & kmask2) << 2)) - 32)
-                     + sc[3] * y[ 56] * ((int8_t)((ql[56] & 0xF) | ((qh[24] & kmask2) << 2)) - 32)
-                     + sc[4] * y[ 64] * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & kmask3) << 0)) - 32)
-                     + sc[4] * y[ 72] * ((int8_t)((ql[ 8] >> 4) | ((qh[ 8] & kmask3) << 0)) - 32)
-                     + sc[5] * y[ 80] * ((int8_t)((ql[16] >> 4) | ((qh[16] & kmask3) << 0)) - 32)
-                     + sc[5] * y[ 88] * ((int8_t)((ql[24] >> 4) | ((qh[24] & kmask3) << 0)) - 32)
-                     + sc[6] * y[ 96] * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & kmask4) >> 2)) - 32)
-                     + sc[6] * y[104] * ((int8_t)((ql[40] >> 4) | ((qh[ 8] & kmask4) >> 2)) - 32)
-                     + sc[7] * y[112] * ((int8_t)((ql[48] >> 4) | ((qh[16] & kmask4) >> 2)) - 32)
-                     + sc[7] * y[120] * ((int8_t)((ql[56] >> 4) | ((qh[24] & kmask4) >> 2)) - 32);
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
 
-        sumf += dall * result;
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
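For context on what these bit manipulations are unpacking, below is a minimal C sketch (not ggml code; dequantize_q6_k_block is an illustrative name) of expanding one Q6_K super-block with the same layout the kernel reads. It assumes QK_K == 256, the k-quant block layout of 128 bytes of low nibbles in ql, 64 bytes of packed high bit-pairs in qh, 16 signed 8-bit group scales, and a super-block scale d (stored as fp16 in ggml, taken as float here), and that kmask1..kmask4 select qh bit pairs 0-1, 2-3, 4-5 and 6-7 (0x03, 0x0C, 0x30, 0xC0), which is what the shifts above imply:

    #include <stdint.h>

    #define QK_K 256

    /* Expand one Q6_K super-block of QK_K values into floats.
     * ql:     low 4 bits of each quant (QK_K/2 bytes)
     * qh:     high 2 bits of each quant, packed four per byte (QK_K/4 bytes)
     * scales: 16 signed 8-bit group scales, one per 16 values
     * d:      super-block scale */
    static void dequantize_q6_k_block(const uint8_t ql[QK_K/2],
                                      const uint8_t qh[QK_K/4],
                                      const int8_t  scales[QK_K/16],
                                      float d, float y[QK_K]) {
        for (int ip = 0; ip < 2; ++ip) {        /* two halves of 128 values */
            const uint8_t *pl = ql + 64*ip;
            const uint8_t *ph = qh + 32*ip;
            const int8_t  *sc = scales + 8*ip;
            float         *py = y + 128*ip;
            for (int l = 0; l < 32; ++l) {
                const int is = l/16;            /* 16-value scale group within the half */
                /* reassemble the four 6-bit quants that share byte l of qh */
                const int q0 = (int8_t)((pl[l     ] & 0xF) | ((ph[l] & 0x03) << 4)) - 32;
                const int q1 = (int8_t)((pl[l + 32] & 0xF) | ((ph[l] & 0x0C) << 2)) - 32;
                const int q2 = (int8_t)((pl[l     ] >>  4) | ((ph[l] & 0x30) << 0)) - 32;
                const int q3 = (int8_t)((pl[l + 32] >>  4) | ((ph[l] & 0xC0) >> 2)) - 32;
                py[l +  0] = d * sc[is + 0] * q0;
                py[l + 32] = d * sc[is + 2] * q1;
                py[l + 64] = d * sc[is + 4] * q2;
                py[l + 96] = d * sc[is + 6] * q3;
            }
        }
    }

Seen against this layout, the change above has each thread read n = 4 consecutive bytes of ql/qh per iteration instead of strided single bytes, accumulate the four 32-value regions in a float4, and apply the group scales sc[0], sc[2], sc[4], sc[6] once per accumulator rather than once per element; the contiguous accesses are presumably where the extra bandwidth comes from.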