This commit is contained in:
Eve 2025-01-06 17:13:23 -05:00
parent cdf70cf27f
commit 6f5d62b098
2 changed files with 19 additions and 17 deletions

View file

@ -52,9 +52,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2; const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f)); const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h)); const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const FLOAT_TYPE sc0 = scale_0_4_l_f.x; const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y; const FLOAT_TYPE sc1 = scale_0_4_l_f.y;

View file

@ -46,21 +46,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec2 scale0 = uvec2(unpack8(scale0_u16));
uvec2 scale4 = uvec2(unpack8(scale4_u16));
uvec2 scale8 = uvec2(unpack8(scale8_u16));
const uint32_t sc0 = ( scale0.x & 0x3f); const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t sc1 = ( scale0.y & 0x3f); const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const uint32_t sc2 = ( scale4.x & 0x3f); const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const uint32_t sc3 = ( scale4.y & 0x3f); const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);