gla: Put the barrier inside the main logic loop
This commit is contained in:
parent
81d852902e
commit
db3ded2c8a
1 changed file with 3 additions and 1 deletion
|
@@ -30,9 +30,11 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
|
||||||
for (u_int i = 0; i < head_size; i++) {
|
for (u_int i = 0; i < head_size; i++) {
|
||||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||||
}
|
}
|
||||||
item.barrier(sycl::access::fence_space::local_space); //sync threads
|
|
||||||
for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||||
|
|
||||||
|
item.barrier(sycl::access::fence_space::local_space); //sync threads
|
||||||
_k[tid] = k[t];
|
_k[tid] = k[t];
|
||||||
_r[tid] = r[t];
|
_r[tid] = r[t];
|
||||||
_td[tid] = td[t];
|
_td[tid] = td[t];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue