gla: Put the barrier inside the main logic loop

This commit is contained in:
Akarshan Biswas 2025-01-13 17:44:07 +05:30
parent 81d852902e
commit db3ded2c8a
No known key found for this signature in database
GPG key ID: 52A578A14B32134D

View file

@@ -30,9 +30,11 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
for (u_int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
item.barrier(sycl::access::fence_space::local_space); //sync threads
for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
item.barrier(sycl::access::fence_space::local_space); //sync threads
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];