From 6a3b84236de279f0fe012cfca0c168472526b696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 13 Apr 2024 22:05:43 +0200 Subject: [PATCH] fix flash_attn_vec_f16 race condition --- ggml-cuda/fattn.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 5f1345a7f..36479b217 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16( VKQ += V_k*KQ2[k0/2]; } } + + __syncthreads(); } if (tid >= D) { @@ -547,7 +549,7 @@ template void launch_fattn_vec_f16( dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } - constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -561,7 +563,7 @@ template void launch_fattn_vec_f16( (const char *) K->data, (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -572,7 +574,7 @@ template void launch_fattn_vec_f16( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { + if (parallel_blocks == 1) { return; }