metal : Q3_K 1st optimization pass
This commit is contained in:
parent
27a69d6a75
commit
3d5ff127ca
1 changed files with 33 additions and 28 deletions
|
@ -1087,23 +1087,27 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|||
const int n = 8;
|
||||
const int l0 = n*ir;
|
||||
|
||||
const uint8_t m = 1 << (4*ip + il);
|
||||
|
||||
const int shift = 2*il;
|
||||
const int is = 8*ip + 2*il;
|
||||
|
||||
//const int shift2 = 4*ip;
|
||||
|
||||
uint16_t aux[8];
|
||||
thread const int8_t * scales = (thread const int8_t*)aux;
|
||||
|
||||
const int q_offset = 32*ip + l0;
|
||||
const int y_offset = 128*ip + 32*il + l0;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
|
||||
const float d_all = (float)(x[i].d);
|
||||
|
||||
device const uint8_t * q = x[i].qs + 32*ip + l0;
|
||||
device const uint8_t * q = x[i].qs + q_offset;
|
||||
device const uint8_t * h = x[i].hmask + l0;
|
||||
device const float * y = yy + i * QK_K + 128*ip + 32*il + l0;
|
||||
|
||||
//device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
||||
//aux[0] = (a[0] & kmask2) | (((a[2] >> 0) & kmask1) << 4);
|
||||
//aux[1] = (a[1] & kmask2) | (((a[2] >> 2) & kmask1) << 4);
|
||||
//aux[2] = ((a[0] >> 4) & kmask2) | (((a[2] >> 4) & kmask1) << 4);
|
||||
//aux[3] = ((a[1] >> 4) & kmask2) | (((a[2] >> 6) & kmask1) << 4);
|
||||
device const float * y = yy + i * QK_K + y_offset;
|
||||
|
||||
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
||||
aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
|
||||
|
@ -1115,29 +1119,30 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|||
aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
|
||||
aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
|
||||
|
||||
uint8_t m = 1 << (4*ip + il);
|
||||
int is = 8*ip + 2*il;
|
||||
float dl;
|
||||
//for (int n = 0; n < QK_K; n += 128) {
|
||||
int shift = 2*il;
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
|
||||
dl = d_all * (scales[is++] - 32);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sumf += y[l+ 0] * dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||
}
|
||||
|
||||
dl = d_all * (scales[is++] - 32);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sumf += y[l+16] * dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||
}
|
||||
|
||||
y += 32;
|
||||
shift += 2;
|
||||
m <<= 1;
|
||||
//}
|
||||
//q += 32;
|
||||
//dl = d_all * (scales[is+0] - 32);
|
||||
//for (int l = 0; l < n; ++l) {
|
||||
// sumf += y[l+ 0] * dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||
//}
|
||||
|
||||
//dl = d_all * (scales[is+1] - 32);
|
||||
//for (int l = 0; l < n; ++l) {
|
||||
// sumf += y[l+16] * dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||
//}
|
||||
|
||||
float s = 0;
|
||||
for (int l = 0; l < n; ++l) {
|
||||
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||
}
|
||||
sumf += s * d_all * (scales[is+0] - 32);
|
||||
|
||||
s = 0;
|
||||
for (int l = 0; l < n; ++l) {
|
||||
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||
}
|
||||
sumf += s * d_all * (scales[is+1] - 32);
|
||||
|
||||
}
|
||||
|
||||
sum[ith] = sumf;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue