add [[unroll]] and remove unnecessary conditions

This commit is contained in:
Zhiyuan Li 2024-12-16 17:35:44 +08:00
parent 64c16c4ae0
commit 6ea605ddfc

View file

@ -1,5 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
@ -29,12 +31,12 @@ void main() {
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
if (tid >= head_size || batch_id >= B || head_id >= H) {
if (batch_id >= B || head_id >= H) {
return;
}
A_TYPE state[BLOCK_SIZE];
for (uint i = 0; i < head_size; i++) {
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
@ -56,7 +58,7 @@ void main() {
const A_TYPE v_val = v[t];
A_TYPE y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
@ -78,7 +80,7 @@ void main() {
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
[[unroll]] for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}