q2_k better dequant
parent 91f1d9ce99
commit cc28742ca3
2 changed files with 41 additions and 40 deletions
@@ -40,7 +40,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-        f16vec2 d = data_a[ib0 + i].d;
+        const f16vec2 d = data_a[ib0 + i].d;
         const FLOAT_TYPE dall = d.x;
         const FLOAT_TYPE dmin = d.y;
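For reference, the `d` read above packs the two per-super-block factors of the Q2_K format: `d.x` is the overall scale (`dall`) and `d.y` the overall min (`dmin`), while each `scales` byte carries a 4-bit sub-block scale and a 4-bit sub-block min. A minimal C sketch of that layout, written from the public ggml definition rather than this repository's headers, so take the field names as assumptions:

```c
#include <stdint.h>

#define QK_K 256  /* elements per super-block */

/* Assumed mirror of ggml's block_q2_K; fp16 fields shown as raw uint16_t. */
typedef struct {
    uint8_t  scales[QK_K / 16]; /* per 16-element sub-block: low nibble = scale, high nibble = min */
    uint8_t  qs[QK_K / 4];      /* 2-bit quants, four packed per byte */
    uint16_t d;                 /* fp16 super-block scale (dall = d.x in the shader) */
    uint16_t dmin;              /* fp16 super-block min   (dmin = d.y in the shader) */
} block_q2_K;

/* Dequantized value of one element, given its sub-block scale/min nibbles
 * and its 2-bit quant: x ~= dall * sc * q - dmin * m */
static inline float dequant_q2_k(float dall, float dmin, unsigned sc, unsigned m, unsigned q) {
    return dall * (float)(sc * q) - dmin * (float)m;
}
```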
@@ -48,10 +48,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
         barrier();

-        uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2];
-        uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-        uvec2 qs0 = uvec2(unpack8(qs0_u16));
-        uvec2 qs16 = uvec2(unpack8(qs16_u16));
+        const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
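The replacement above trades two 16-bit unpacks for one 32-bit word: the two halves of the row are OR-ed together, each 2-bit plane is isolated with a 0x03030303 mask after shifting by 0, 2, 4 or 6, and unpack8 then yields four quant values at once, already as floats in a vec4. A small stand-alone C check of that bit trick, with unpack8 re-implemented by hand and arbitrary test values (an illustration of the idea, not shader code):

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Emulate GLSL unpack8(uint): split a 32-bit word into its four bytes (low byte first). */
static void unpack8(uint32_t w, uint8_t out[4]) {
    for (int k = 0; k < 4; ++k) out[k] = (uint8_t)(w >> (8 * k));
}

int main(void) {
    /* Two 16-bit halves as loaded from data_a_packed16 (arbitrary test values). */
    uint16_t qs0_u16 = 0xB1E4, qs16_u16 = 0x7C39;

    /* New path: one 32-bit word, masked per 2-bit plane, then unpacked per byte. */
    uint32_t qs_u32 = (uint32_t)qs0_u16 | ((uint32_t)qs16_u16 << 16);

    /* Old path: the individual bytes, each shifted and masked separately. */
    uint8_t bytes[4];
    unpack8(qs_u32, bytes);

    for (int shift = 0; shift < 8; shift += 2) {
        uint8_t plane[4];
        unpack8((qs_u32 >> shift) & 0x03030303u, plane);
        for (int k = 0; k < 4; ++k)
            assert(plane[k] == ((bytes[k] >> shift) & 3));
    }
    puts("packed 2-bit extraction matches per-byte extraction");
    return 0;
}
```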
@@ -66,14 +67,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l]       ) & 3),
-                       fma(FLOAT_TYPE(b16[l]),  sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l]      ) & 3),
-                       fma(FLOAT_TYPE(b32[l]),  sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l]  >> 2) & 3),
-                       fma(FLOAT_TYPE(b48[l]),  sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                       fma(FLOAT_TYPE(b64[l]),  sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l]  >> 4) & 3),
-                       fma(FLOAT_TYPE(b80[l]),  sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                       fma(FLOAT_TYPE(b96[l]),  sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l]  >> 6) & 3),
-                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache[ix][v_im][0] * qs_u32_0[l  ],
+                       fma(FLOAT_TYPE(b16[l]),  sccache[ix][v_im][1] * qs_u32_0[l+2],
+                       fma(FLOAT_TYPE(b32[l]),  sccache[ix][v_im][2] * qs_u32_2[l  ],
+                       fma(FLOAT_TYPE(b48[l]),  sccache[ix][v_im][3] * qs_u32_2[l+2],
+                       fma(FLOAT_TYPE(b64[l]),  sccache[ix][v_im][4] * qs_u32_4[l  ],
+                       fma(FLOAT_TYPE(b80[l]),  sccache[ix][v_im][5] * qs_u32_4[l+2],
+                       fma(FLOAT_TYPE(b96[l]),  sccache[ix][v_im][6] * qs_u32_6[l  ],
+                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
                 sum2 = fma(FLOAT_TYPE(b0[l]),  sccache[ix][v_im][ 8],
                        fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
                        fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
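In the rewritten accumulation above, sum1 collects the b * sub_scale * quant terms and sum2 the quant-free b * sub_min terms; in the usual K-quant scheme the two are later combined with the super-block factors, roughly as dall * sum1 - dmin * sum2 (that combining step is outside this hunk). A scalar C sketch of the same pattern, chaining fmaf the way the shader chains fma; the names are illustrative, not the shader's:

```c
#include <math.h>

/* One 16-element group of the Q2_K dot product, scalar form.
 * b: activations, q: 2-bit quants (0..3), sc/m: sub-block scale and min nibbles. */
static float q2k_group_dot(const float *b, const unsigned char *q,
                           float sc, float m, float dall, float dmin) {
    float sum1 = 0.0f; /* scale * quant terms */
    float sum2 = 0.0f; /* min terms           */
    for (int l = 0; l < 16; ++l) {
        sum1 = fmaf(b[l], sc * (float)q[l], sum1);
        sum2 = fmaf(b[l], m,                sum2);
    }
    return dall * sum1 - dmin * sum2;
}
```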

(second changed file)

@@ -64,47 +64,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         const FLOAT_TYPE sc6 = scale8_f.z;
         const FLOAT_TYPE sc7 = scale8_f.w;

-        uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
-        uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
+        const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);

         uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
         uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
         uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
         uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;

-        uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
+        const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));

-        uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
-        uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
-        uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
-        uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
+        const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
+        const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
+        const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
+        const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;

         qs0_16_u32_lo4 += qs0_16_lo4_offset16;
         qs0_16_u32_hi4 += qs0_16_hi4_offset16;
         qs64_80_u32_lo4 += qs64_80_lo4_offset16;
         qs64_80_u32_hi4 += qs64_80_hi4_offset16;

-        uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
-        uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
-        uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
-        uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
+        const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
+        const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
+        const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
+        const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));

-        const uint32_t q4_0 = qs0_16_lo4.x;
-        const uint32_t q4_1 = qs0_16_lo4.y;
-        const uint32_t q4_2 = qs0_16_lo4.z;
-        const uint32_t q4_3 = qs0_16_lo4.w;
-        const uint32_t q4_4 = qs0_16_hi4.x;
-        const uint32_t q4_5 = qs0_16_hi4.y;
-        const uint32_t q4_6 = qs0_16_hi4.z;
-        const uint32_t q4_7 = qs0_16_hi4.w;
-        const uint32_t q4_8 = qs64_80_lo4.x;
-        const uint32_t q4_9 = qs64_80_lo4.y;
-        const uint32_t q4_10 = qs64_80_lo4.z;
-        const uint32_t q4_11 = qs64_80_lo4.w;
-        const uint32_t q4_12 = qs64_80_hi4.x;
-        const uint32_t q4_13 = qs64_80_hi4.y;
-        const uint32_t q4_14 = qs64_80_hi4.z;
-        const uint32_t q4_15 = qs64_80_hi4.w;
+        const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
+        const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
+        const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
+        const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
+        const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
+        const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
+        const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
+        const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
+        const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
+        const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
+        const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
+        const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
+        const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
+        const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
+        const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
+        const FLOAT_TYPE q4_15 = qs64_80_hi4.w;

         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
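The hunk above belongs to the second changed file, apparently the Q5_K path given the qh handling: each 4-bit nibble gets its fifth bit from qh, with a different bit of every qh byte shifted so that it lands at bit 4 and is simply added (q5 = nibble + 16 when the bit is set); the rest of the change is the same const/vec4/FLOAT_TYPE cleanup as in the first file. A hedged C sketch of that reconstruction, doing four byte lanes at once the way the shader does (names and test values are made up):

```c
#include <assert.h>
#include <stdint.h>

/* Rebuild four 5-bit Q5_K quants from their 4-bit nibbles plus high bits,
 * SWAR-style on a 32-bit word, mirroring the "lo4 + offset16" pattern above. */
static uint32_t q5_lanes(uint32_t qs_u32, uint32_t qh, int v_im) {
    uint32_t lo4 = qs_u32 & 0x0F0F0F0Fu;             /* four 4-bit nibbles          */
    lo4 += ((qh >> (2 * v_im)) & 0x01010101u) << 4;  /* +16 where the qh bit is set */
    return lo4;                                      /* four 5-bit values           */
}

int main(void) {
    uint32_t qs = 0x0F030A01u;  /* arbitrary nibbles   */
    uint32_t qh = 0x01000100u;  /* arbitrary high bits */
    uint32_t q5 = q5_lanes(qs, qh, 0);
    /* lanes whose qh bit is set gained +16 */
    assert(((q5 >>  0) & 0xFF) ==  1);
    assert(((q5 >>  8) & 0xFF) == 10 + 16);
    assert(((q5 >> 16) & 0xFF) ==  3);
    assert(((q5 >> 24) & 0xFF) == 15 + 16);
    return 0;
}
```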