Add [[unroll]] to the head_size loops and remove the unnecessary `tid >= head_size` check
This commit is contained in:
parent
64c16c4ae0
commit
6ea605ddfc
1 changed file with 6 additions and 4 deletions
|
@@ -1,5 +1,7 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
@@ -29,12 +31,12 @@ void main() {
|
|||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
if (tid >= head_size || batch_id >= B || head_id >= H) {
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
A_TYPE state[BLOCK_SIZE];
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
@@ -56,7 +58,7 @@ void main() {
|
|||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
|
@@ -78,7 +80,7 @@ void main() {
|
|||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue