metal : Q3_K second optimization pass - 29.6 ms/token
This commit is contained in:
parent
3d5ff127ca
commit
df2c1dc738
1 changed files with 32 additions and 29 deletions
|
@ -1092,7 +1092,18 @@ kernel void kernel_mul_mat_q3_k_f32(
|
||||||
const int shift = 2*il;
|
const int shift = 2*il;
|
||||||
const int is = 8*ip + 2*il;
|
const int is = 8*ip + 2*il;
|
||||||
|
|
||||||
//const int shift2 = 4*ip;
|
// is = 0 -> shift1 = 0, shift2 = 0, accessing aux[0] -> (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); ip = 0, il = 0
|
||||||
|
// is = 2 -> shift1 = 0, shift2 = 0, accessing aux[1] -> (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); ip = 0, il = 1
|
||||||
|
// is = 4 -> shift1 = 0, shift2 = 2, accessing aux[2] -> (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4); ip = 0, il = 2
|
||||||
|
// is = 6 -> shift1 = 0, shift2 = 2, accessing aux[3] -> (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4); ip = 0, il = 3
|
||||||
|
// is = 8 -> shift1 = 4, shift2 = 4, accessing aux[4] -> ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4); ip = 1, il = 0
|
||||||
|
// is =10 -> shift1 = 4, shift2 = 4, accessing aux[5] -> ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4); ip = 1, il = 1
|
||||||
|
// is =12 -> shift1 = 4, shift2 = 6, accessing aux[6] -> ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); ip = 1, il = 2
|
||||||
|
// is =14 -> shift1 = 4, shift2 = 6, accessing aux[7] -> ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); ip = 1, il = 3
|
||||||
|
|
||||||
|
const int s_shift1 = 4*ip;
|
||||||
|
const int s_shift2 = s_shift1 + 2*(il/2);
|
||||||
|
const int ik = 4 + (il%2);
|
||||||
|
|
||||||
uint16_t aux[8];
|
uint16_t aux[8];
|
||||||
thread const int8_t * scales = (thread const int8_t*)aux;
|
thread const int8_t * scales = (thread const int8_t*)aux;
|
||||||
|
@ -1110,49 +1121,41 @@ kernel void kernel_mul_mat_q3_k_f32(
|
||||||
device const float * y = yy + i * QK_K + y_offset;
|
device const float * y = yy + i * QK_K + y_offset;
|
||||||
|
|
||||||
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
||||||
aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
|
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
||||||
aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
|
|
||||||
aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
|
|
||||||
aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
|
|
||||||
aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
|
|
||||||
aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
|
|
||||||
aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
|
|
||||||
aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
|
|
||||||
|
|
||||||
float dl;
|
|
||||||
|
|
||||||
//dl = d_all * (scales[is+0] - 32);
|
|
||||||
//for (int l = 0; l < n; ++l) {
|
|
||||||
// sumf += y[l+ 0] * dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
|
||||||
//}
|
|
||||||
|
|
||||||
//dl = d_all * (scales[is+1] - 32);
|
|
||||||
//for (int l = 0; l < n; ++l) {
|
|
||||||
// sumf += y[l+16] * dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
|
||||||
//}
|
|
||||||
|
|
||||||
float s = 0;
|
float s = 0;
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||||
}
|
}
|
||||||
sumf += s * d_all * (scales[is+0] - 32);
|
sumf += s * d_all * (scales[0] - 32);
|
||||||
|
|
||||||
s = 0;
|
s = 0;
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||||
}
|
}
|
||||||
sumf += s * d_all * (scales[is+1] - 32);
|
sumf += s * d_all * (scales[1] - 32);
|
||||||
|
|
||||||
|
//const float d1 = d_all * (scales[0] - 32);
|
||||||
|
//for (int l = 0; l < n; ++l) {
|
||||||
|
// sumf += y[l+ 0] * d1 * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||||
|
//}
|
||||||
|
|
||||||
|
//const float d2 = d_all * (scales[1] - 32);
|
||||||
|
//for (int l = 0; l < n; ++l) {
|
||||||
|
// sumf += y[l+16] * d2 * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||||
|
//}
|
||||||
|
|
||||||
|
//float2 s = {0.f, 0.f};
|
||||||
|
//for (int l = 0; l < n; ++l) {
|
||||||
|
// s[0] += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
||||||
|
// s[1] += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
||||||
|
//}
|
||||||
|
//sumf += d_all * (s[0] * (scales[0] - 32) + s[1] * (scales[1] - 32));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sum[ith] = sumf;
|
sum[ith] = sumf;
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
//if (ith == 0) {
|
|
||||||
// for (int i = 1; i < nth; ++i) sum[0] += sum[i];
|
|
||||||
// dst[r1*ne0 + r0] = sum[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Accumulate the sum from all threads in the threadgroup
|
// Accumulate the sum from all threads in the threadgroup
|
||||||
//
|
//
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue