fix flash_attn_vec_f16 race condition
This commit is contained in:
parent
34f93bbb39
commit
6a3b84236d
1 changed file with 5 additions and 3 deletions
|
@@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
VKQ += V_k*KQ2[k0/2];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid >= D) {
|
||||
|
@@ -547,7 +549,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
|
|||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
|
||||
-constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE;
|
||||
+constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
||||
constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
|
@@ -561,7 +563,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
|
|||
(const char *) K->data,
|
||||
(const char *) V->data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
-(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
+parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
|
@@ -572,7 +574,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
|
|||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
-if ((parallel_blocks) == 1) {
|
||||
+if (parallel_blocks == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue