From d70a731639d9acd0baa588a94bfaa5f928b26b9c Mon Sep 17 00:00:00 2001
From: Eve <139727413+netrunnereve@users.noreply.github.com>
Date: Sat, 4 Jan 2025 20:48:27 -0500
Subject: [PATCH] q2_k

---
 .../vulkan-shaders/mul_mat_vec_q2_k.comp      | 50 ++++++++-----------
 .../vulkan-shaders/mul_mat_vec_q3_k.comp      |  5 +-
 .../vulkan-shaders/mul_mat_vec_q4_k.comp      |  6 +--
 3 files changed, 27 insertions(+), 34 deletions(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 921759bc8..098771493 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -5,6 +5,8 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16];
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -16,6 +18,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
     const uint itid = tid%16;  // 0...15
     const uint ix  = tid/16;
+    const uint itid8 = itid%8;
 
     const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = itid - 8*v_im; // 0...15 or 0...7
@@ -42,18 +45,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         const FLOAT_TYPE dall = d.x;
         const FLOAT_TYPE dmin = d.y;
 
-        uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4];
-        uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-        uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-        uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-        uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-        uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-        uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-        uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
+        sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
+        sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
+        barrier();
 
         uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2];
         uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
@@ -73,22 +67,22 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3),
-                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3),
-                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3),
-                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-                sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]),
-                       fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]),
-                       fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]),
-                       fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]),
-                       fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]),
-                       fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]),
-                       fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]),
-                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l]       ) & 3),
+                       fma(FLOAT_TYPE(b16[l]),  sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l]      ) & 3),
+                       fma(FLOAT_TYPE(b32[l]),  sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l]  >> 2) & 3),
+                       fma(FLOAT_TYPE(b48[l]),  sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b64[l]),  sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l]  >> 4) & 3),
+                       fma(FLOAT_TYPE(b80[l]),  sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b96[l]),  sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l]  >> 6) & 3),
+                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache[ix][v_im][ 8],
+                       fma(FLOAT_TYPE(b16[l]),  sccache[ix][v_im][ 9],
+                       fma(FLOAT_TYPE(b32[l]),  sccache[ix][v_im][10],
+                       fma(FLOAT_TYPE(b48[l]),  sccache[ix][v_im][11],
+                       fma(FLOAT_TYPE(b64[l]),  sccache[ix][v_im][12],
+                       fma(FLOAT_TYPE(b80[l]),  sccache[ix][v_im][13],
+                       fma(FLOAT_TYPE(b96[l]),  sccache[ix][v_im][14],
+                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2))))))));
             }
             temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
         }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
index 0266417a8..0658f46bd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -5,7 +5,6 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared uint32_t scu8[BLOCK_SIZE/16][12];
 shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][12];
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
@@ -61,8 +60,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
             FLOAT_TYPE sum = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m )) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)),
+                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m     )) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m     )) != 0) ? 0 : 4)),
                       fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
                       fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
                       fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index 225f0ce70..8a3644e96 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -47,9 +47,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
-        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
-        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
-        const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+        uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+        uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+        uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
         uvec2 scale0 = uvec2(unpack8(scale0_u32));
         uvec2 scale4 = uvec2(unpack8(scale4_u32));
         uvec2 scale8 = uvec2(unpack8(scale8_u32));
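
For readers unfamiliar with the pattern, below is a minimal standalone GLSL sketch of the shared-memory scale cache the q2_k hunks introduce, reduced to a self-contained compute shader. The buffer block, binding, and names (Scales, scales_packed) are simplified stand-ins rather than ggml's real block_q2_K definition; only the thread indexing (itid, itid8, v_im) and the bitfieldExtract + barrier() sequence mirror the patch.

#version 450

// Hypothetical standalone illustration, not part of the patch.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

// The 16 scale bytes of one q2_K block, packed into four 32-bit words.
layout(std430, binding = 0) readonly buffer Scales { uint scales_packed[]; };

// [16-thread group][lower/upper 8 scale bytes][thread within group]
shared float sccache[2][2][16];

void main() {
    const uint tid   = gl_LocalInvocationID.x;
    const uint itid  = tid % 16; // 0...15 within a 16-thread group
    const uint ix    = tid / 16; // which 16-thread group of the workgroup
    const uint itid8 = itid % 8; // which scale byte within an 8-byte half
    const uint v_im  = itid / 8; // 0 -> low nibble (scale), 1 -> high nibble (min)

    // Fetch bytes scales[itid8] and scales[itid8 + 8] out of the packed words.
    const uint lo_byte = bitfieldExtract(scales_packed[ itid8      / 4], int(( itid8      % 4) * 8), 8);
    const uint hi_byte = bitfieldExtract(scales_packed[(itid8 + 8) / 4], int(((itid8 + 8) % 4) * 8), 8);

    // Keep only the 4-bit field selected by v_im, as in the patch:
    // entries 0..7 end up holding sub-block scales, entries 8..15 the mins.
    sccache[ix][0][itid] = float(bitfieldExtract(lo_byte, int(v_im * 4), 4)); // lower 8 bytes
    sccache[ix][1][itid] = float(bitfieldExtract(hi_byte, int(v_im * 4), 4)); // upper 8 bytes
    barrier(); // every thread can now read any cached scale/min of the block
}

The effect of the change is visible in the q2_k hunks themselves: the per-iteration unpack8 decode of the packed scales is dropped in favour of a one-time cooperative load into sccache, so the sum1/sum2 chains read ready-made FLOAT_TYPE scales and mins from shared memory, and the barrier() is what makes that load safe before any thread consumes an entry written by another thread.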