Much better Q6_K implementation for Metal
28.3 ms/token for the 7B model. Subtracting the ~9 ms spent in other compute-graph operations leaves ~19 ms for the matrix multiplications. The model is ~5.5 GB, so the effective memory bandwidth is 1000 / 19 * 5.5 ≈ 290 GB/s!
This commit is contained in:
parent
4a82c8d45c
commit
c1be42404a
1 changed files with 13 additions and 21 deletions
|
@@ -782,37 +782,29 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|||
const int iqs = step * tpitg.y; // 0...240 in steps of 16
|
||||
const int ip = iqs / 128; // 0 or 1
|
||||
const int il = (iqs - 128*ip)/16; // 0...7
|
||||
const int is = 8*ip;
|
||||
const int n = 4;
|
||||
const int is = 8*ip + (n*il)/16;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
|
||||
device const uint8_t * ql = x[i].ql + 64*ip + il;
|
||||
device const uint8_t * qh = x[i].qh + 32*ip + il;
|
||||
device const uint8_t * ql = x[i].ql + 64*ip + n*il;
|
||||
device const uint8_t * qh = x[i].qh + 32*ip + n*il;
|
||||
device const int8_t * sc = x[i].scales + is;
|
||||
|
||||
device const float * y = yy + i * QK_K + 128*ip + il;
|
||||
device const float * y = yy + i * QK_K + 128*ip + n*il;
|
||||
|
||||
const float dall = x[i].d;
|
||||
|
||||
float result = sc[0] * y[ 0] * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & kmask1) << 4)) - 32)
|
||||
+ sc[0] * y[ 8] * ((int8_t)((ql[ 8] & 0xF) | ((qh[ 8] & kmask1) << 4)) - 32)
|
||||
+ sc[1] * y[ 16] * ((int8_t)((ql[16] & 0xF) | ((qh[16] & kmask1) << 4)) - 32)
|
||||
+ sc[1] * y[ 24] * ((int8_t)((ql[24] & 0xF) | ((qh[24] & kmask1) << 4)) - 32)
|
||||
+ sc[2] * y[ 32] * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & kmask2) << 2)) - 32)
|
||||
+ sc[2] * y[ 40] * ((int8_t)((ql[40] & 0xF) | ((qh[ 8] & kmask2) << 2)) - 32)
|
||||
+ sc[3] * y[ 48] * ((int8_t)((ql[48] & 0xF) | ((qh[16] & kmask2) << 2)) - 32)
|
||||
+ sc[3] * y[ 56] * ((int8_t)((ql[56] & 0xF) | ((qh[24] & kmask2) << 2)) - 32)
|
||||
+ sc[4] * y[ 64] * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & kmask3) << 0)) - 32)
|
||||
+ sc[4] * y[ 72] * ((int8_t)((ql[ 8] >> 4) | ((qh[ 8] & kmask3) << 0)) - 32)
|
||||
+ sc[5] * y[ 80] * ((int8_t)((ql[16] >> 4) | ((qh[16] & kmask3) << 0)) - 32)
|
||||
+ sc[5] * y[ 88] * ((int8_t)((ql[24] >> 4) | ((qh[24] & kmask3) << 0)) - 32)
|
||||
+ sc[6] * y[ 96] * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & kmask4) >> 2)) - 32)
|
||||
+ sc[6] * y[104] * ((int8_t)((ql[40] >> 4) | ((qh[ 8] & kmask4) >> 2)) - 32)
|
||||
+ sc[7] * y[112] * ((int8_t)((ql[48] >> 4) | ((qh[16] & kmask4) >> 2)) - 32)
|
||||
+ sc[7] * y[120] * ((int8_t)((ql[56] >> 4) | ((qh[24] & kmask4) >> 2)) - 32);
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
}
|
||||
|
||||
sumf += dall * result;
|
||||
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue