CUDA: fix FA out-of-bounds writes (#7465)

This commit is contained in:
Johannes Gäßler 2024-05-22 17:58:25 +02:00 committed by GitHub
parent b18532a4ef
commit 38c03478a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 2 deletions

View file

@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ic0 + j_VKQ >= ne01) {
break;
}
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols) {
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
}