From 6f5d62b098a45d9c4a0d833d03ec68848b0c06b5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:13:23 -0500 Subject: [PATCH] q5_k --- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 ++-- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 30 ++++++++++--------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6761f3d32..28cde16db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -52,9 +52,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; - const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2; - const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f)); - const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h)); + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); const FLOAT_TYPE sc0 = scale_0_4_l_f.x; const FLOAT_TYPE sc1 = scale_0_4_l_f.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index c4aaf9fea..b5e8399fe 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -46,21 +46,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; - uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u16)); - uvec2 scale4 = uvec2(unpack8(scale4_u16)); - uvec2 scale8 = uvec2(unpack8(scale8_u16)); + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);