fix compiler warning

This commit is contained in:
Johannes Gäßler 2024-05-09 13:51:48 +02:00
parent 2272765196
commit fece1fe482

View file

@ -225,13 +225,12 @@ static __global__ void flash_attn_vec_ext_f16(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
if (parallel_blocks == 1 || tid != 0) {
return;
}
if (parallel_blocks != 1 && tid != 0) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
}
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE