fix compiler warning
This commit is contained in:
parent
2272765196
commit
fece1fe482
1 changed files with 4 additions and 5 deletions
|
@ -225,12 +225,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks == 1 || tid != 0) {
|
if (parallel_blocks != 1 && tid != 0) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue