diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 8082e9464..7c486f482 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -225,12 +225,11 @@ static __global__ void flash_attn_vec_ext_f16( dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - if (parallel_blocks == 1 || tid != 0) { - return; - } + if (parallel_blocks != 1 && tid != 0) { #pragma unroll - for (int j = 0; j < ncols; ++j) { - dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); + for (int j = 0; j < ncols; ++j) { + dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); + } } #else NO_DEVICE_CODE;