From 25d7ae429df749b7f142f29a5cf77f456e7bffcc Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 27 Dec 2024 15:44:59 -0500 Subject: [PATCH] go even further --- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index ccaa0486d..f1464b899 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -10,6 +10,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 1) const uint NUM_ROWS = 1; shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16]; void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { @@ -57,12 +58,13 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // cache full superblock into shared memory with coalesced reads // we assume 64 threads here! - [[unroll]] for (int l = 0; (l < 4) && (i0 + l < num_blocks_per_row); ++l) { - blkcache[l].ql[tid] = data_a_packed16[ib0 + i0 + l].ql[tid]; - // hacky method of reading beyond ql and the block struct size but it looks like vulkan doesn't care? o_O - // this assumes that the struct is packed in continous 16 bit blocks to work - blkcache[l].ql[64 + tid] = data_a_packed16[ib0 + i0 + l].ql[64 + tid]; + // + // hacky method of reading beyond ql and the block struct size but it looks like vulkan doesn't care? o_O + // this assumes that the struct is packed in continous 16 bit blocks to work + [[unroll]] for (int l = 0; l < 7; ++l) { + blkcache[0].ql[tid + 64*l] = data_a_packed16[ib0 + i0].ql[tid + 64*l]; } + sccache[ix][itid] = FLOAT_TYPE(blkcache[ix].scales[itid]); barrier(); if (i >= num_blocks_per_row) continue; @@ -100,7 +102,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { } [[unroll]] for (int l = 0; l < 4; ++l) - sum[l] *= FLOAT_TYPE(blkcache[ix].scales[s_offset + l*2]); + sum[l] *= sccache[ix][s_offset + l*2]; temp[n] += (sum[0] + sum[1] + sum[2] + sum[3]) * FLOAT_TYPE(blkcache[ix].d); } }