diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp index 8beb7ff6e..35cc6c45f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -1,5 +1,7 @@ #version 450 +#extension GL_EXT_control_flow_attributes : require + #define BLOCK_SIZE 64 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; @@ -29,12 +31,12 @@ void main() { const uint state_size = C * head_size; const uint n_seq_tokens = T / B; - if (tid >= head_size || batch_id >= B || head_id >= H) { + if (batch_id >= B || head_id >= H) { return; } A_TYPE state[BLOCK_SIZE]; - for (uint i = 0; i < head_size; i++) { + [[unroll]] for (uint i = 0; i < head_size; i++) { state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + i * head_size + tid]; } @@ -56,7 +58,7 @@ void main() { const A_TYPE v_val = v[t]; A_TYPE y = 0.0; - for (uint j = 0; j < head_size; j += 4) { + [[unroll]] for (uint j = 0; j < head_size; j += 4) { vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); @@ -78,7 +80,7 @@ void main() { dst[t] = y; } - for (uint i = 0; i < head_size; i++) { + [[unroll]] for (uint i = 0; i < head_size; i++) { dst[T * C + batch_id * state_size + head_id * head_size * head_size + i * head_size + tid] = state[i]; }