q3_k use hmask simd from cpu avx version

This commit is contained in:
Eve 2025-01-08 21:13:09 -05:00
parent fe71a8c4a1
commit 923e9a8377

View file

@ -21,9 +21,13 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint itid8 = itid%8;
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_im4 = v_im*4;
const uint v_in = itid - 8*v_im; // 0...7
const uint8_t m = uint8_t(1 << (4 * v_im));
const uint32_t m = 0x01010101 << (4 * v_im);
uint32_t hm_m[4];
[[unroll]] for (uint j = 0; j < 4; ++j)
hm_m[j] = m << j;
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
@ -44,7 +48,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32);
barrier();
// 0, 1, 16, 17
@ -55,8 +59,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
@ -71,14 +78,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - FLOAT_TYPE((( hmk0[l] & (m )) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m )) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4),
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4),
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum))))))));
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}