fix flash_attn_vec_f16 race condition

This commit is contained in:
Johannes Gäßler 2024-04-13 22:05:43 +02:00
parent 34f93bbb39
commit 6a3b84236d

View file

@@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16(
VKQ += V_k*KQ2[k0/2];
}
}
+__syncthreads();
}
if (tid >= D) {
@@ -547,7 +549,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
-constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE;
+constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
const int shmem = 0;
@@ -561,7 +563,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
-(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -572,7 +574,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
);
CUDA_CHECK(cudaGetLastError());
-if ((parallel_blocks) == 1) {
+if (parallel_blocks == 1) {
return;
}