From 6145fc79e5117959e49a667ea76f72649922e705 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:41:50 -0500 Subject: [PATCH] q2_k separate out --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 130 ++++++++++-------- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 4 +- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 129442f7e..99db8f32b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -7,6 +7,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16]; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint itid8, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes + sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); + sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + const f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -27,68 +95,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - - sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes - sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes - barrier(); - - const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); - const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); - const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); - const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - - const f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = d.x; - const FLOAT_TYPE dmin = d.y; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; - B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); - } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index e395b6143..6d655205d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -16,7 +16,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); if (i >= num_blocks_per_row) @@ -38,7 +38,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); if (all_threads) { - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); }