Better Q3_K for QK_K = 64
21.6 ms/t -> 21.1 ms/t
This commit is contained in:
parent
0099570f04
commit
d3c3624c7b
1 changed files with 15 additions and 16 deletions
|
@ -1479,8 +1479,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
const int im = il/8; // 0, 0, 1, 1
|
const int im = il/8; // 0, 0, 1, 1
|
||||||
const int in = il%8; // 0, 4, 0, 4
|
const int in = il%8; // 0, 4, 0, 4
|
||||||
|
|
||||||
float4 sum1 = {0.f, 0.f, 0.f, 0.f};
|
float2 sum = {0.f, 0.f};
|
||||||
float4 sum2 = {0.f, 0.f, 0.f, 0.f};
|
|
||||||
|
|
||||||
for (int i = ix; i < nb; i += 8) {
|
for (int i = ix; i < nb; i += 8) {
|
||||||
|
|
||||||
|
@ -1488,28 +1487,28 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
||||||
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
||||||
|
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
||||||
device const float * y = yy + i * QK_K + il;
|
device const float * y = yy + i * QK_K + il;
|
||||||
|
|
||||||
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
||||||
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
||||||
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
||||||
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
||||||
|
|
||||||
for (int l = 0; l < 4; l += 2) {
|
for (int l = 0; l < 4; l += 2) {
|
||||||
const uint16_t hm = h[l/2] >> im;
|
const uint16_t hm = h[l/2] >> im;
|
||||||
sum1[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4));
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
||||||
sum1[1] += y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16));
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
||||||
sum1[2] += y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64));
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
||||||
sum1[3] += y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
||||||
sum2[0] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024));
|
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
||||||
sum2[1] += y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096));
|
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
||||||
sum2[2] += y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384));
|
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
||||||
sum2[3] += y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
const float sumf = sum1[0] + sum1[1] * 1.f/4.f + sum1[2] * 1.f/16.f + sum1[3] * 1.f/64.f +
|
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
||||||
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
|
|
||||||
|
|
||||||
const float tot = simd_sum(sumf);
|
const float tot = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue