CUDA: fix FA out-of-bounds reads (#7479)

This commit is contained in:
Johannes Gäßler 2024-05-23 00:31:20 +02:00 committed by GitHub
parent 1e374365d1
commit cd93a28cb1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 8 additions and 8 deletions

View file

@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}