From ed038a26e37da36a3df5eba8be3cd8a199ecbab5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 31 Dec 2024 23:00:16 -0500 Subject: [PATCH] bct --- .../ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 2d9fd8eb0..d8ce4ed9f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -25,13 +25,12 @@ uint fill_blkcache_its(uint wg_size) { return 8; } -void fill_blkcache(const int num_blocks, const uint ib0, const uint i0, const uint tid, const uint fbi) { - uint bc_t = 104 / fbi; - if (tid < bc_t) { +void fill_blkcache(const int num_blocks, const uint bct, const uint ib0, const uint i0, const uint tid, const uint fbi) { + if (tid < bct) { [[unroll]] for (int l = 0; l < num_blocks; ++l) { [[unroll]] for (int m = 0; m < fbi; ++m) // cache full superblock into shared memory with coalesced reads - blkcache[l].blk[tid + m*bc_t] = data_a_packed16[ib0 + i0 + l].blk[tid + m*bc_t]; + blkcache[l].blk[tid + m*bct] = data_a_packed16[ib0 + i0 + l].blk[tid + m*bct]; } } } @@ -48,6 +47,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint itid = tid%16; // 0...15 const uint ix = tid/16; const uint fbi = fill_blkcache_its(gl_WorkGroupSize.x); + const uint bct = 104/fbi; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_in = itid - 8*v_im; // 0...15 or 0...7 @@ -81,9 +81,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // fill_blkcache is sensitive to unrolling with hardcoded it_size if (blim == it_size) { - fill_blkcache(int(it_size), ib0, i0, tid, fbi); + fill_blkcache(int(it_size), bct, ib0, i0, tid, fbi); } else { - fill_blkcache(blim, ib0, i0, tid, fbi); + fill_blkcache(blim, bct, ib0, i0, tid, fbi); } sccache[ix][itid] = FLOAT_TYPE(int8_t(bitfieldExtract(blkcache[ix].blk[96 + itid/2], int(bcs_offset), 8)));