From cc28742ca39b12efea5f9b8d87d44860a3430ccb Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:20:33 -0500 Subject: [PATCH] q2_k better dequant --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 27 +++++----- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 54 +++++++++---------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 1f00e442d..ed924a072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -40,7 +40,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; + const f16vec2 d = data_a[ib0 + i].d; const FLOAT_TYPE dall = d.x; const FLOAT_TYPE dmin = d.y; @@ -48,10 +48,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes barrier(); - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2]; - uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; - uvec2 qs0 = uvec2(unpack8(qs0_u16)); - uvec2 qs16 = uvec2(unpack8(qs16_u16)); + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; @@ -66,14 +67,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3), - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3), - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index b5e8399fe..9039ce1f3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -64,47 +64,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE sc6 = scale8_f.z; const FLOAT_TYPE sc7 = scale8_f.w; - uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); - uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; qs0_16_u32_lo4 += qs0_16_lo4_offset16; qs0_16_u32_hi4 += qs0_16_hi4_offset16; qs64_80_u32_lo4 += qs64_80_lo4_offset16; qs64_80_u32_hi4 += qs64_80_hi4_offset16; - uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); - uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); - uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); - uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); - const uint32_t q4_0 = qs0_16_lo4.x; - const uint32_t q4_1 = qs0_16_lo4.y; - const uint32_t q4_2 = qs0_16_lo4.z; - const uint32_t q4_3 = qs0_16_lo4.w; - const uint32_t q4_4 = qs0_16_hi4.x; - const uint32_t q4_5 = qs0_16_hi4.y; - const uint32_t q4_6 = qs0_16_hi4.z; - const uint32_t q4_7 = qs0_16_hi4.w; - const uint32_t q4_8 = qs64_80_lo4.x; - const uint32_t q4_9 = qs64_80_lo4.y; - const uint32_t q4_10 = qs64_80_lo4.z; - const uint32_t q4_11 = qs64_80_lo4.w; - const uint32_t q4_12 = qs64_80_hi4.x; - const uint32_t q4_13 = qs64_80_hi4.y; - const uint32_t q4_14 = qs64_80_hi4.z; - const uint32_t q4_15 = qs64_80_hi4.w; + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];