From 3452095089918abf296b653bf93ac8afde7942ce Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 22 Jul 2023 17:46:52 +0200 Subject: [PATCH] Unroll loops in dmmv shader --- vk_shaders/dequant_mul_mat_vec_q4_0.glsl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vk_shaders/dequant_mul_mat_vec_q4_0.glsl b/vk_shaders/dequant_mul_mat_vec_q4_0.glsl index d713497ea..160d603d8 100644 --- a/vk_shaders/dequant_mul_mat_vec_q4_0.glsl +++ b/vk_shaders/dequant_mul_mat_vec_q4_0.glsl @@ -33,11 +33,11 @@ void main() { const int row = int(gl_WorkGroupID.x); const int tid = int(gl_LocalInvocationID.x); - const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + const int y_offset = QUANT_K/2; tmp[tid] = 0; - for (int i = 0; i < p.ncols/block_size; i += 2) { + [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; const int ib = (row*p.ncols + col)/QUANT_K; // block index const int iqs = (col%QUANT_K)/QUANT_R; // quant index @@ -61,7 +61,7 @@ void main() { // sum up partial sums and write back result barrier(); - for (int s=block_size/2; s>0; s>>=1) { + [[unroll]] for (int s=block_size/2; s>0; s>>=1) { if (tid < s) { tmp[tid] += tmp[tid + s]; }