fix batch size 2-8

This commit is contained in:
Johannes Gäßler 2024-05-04 22:01:51 +02:00
parent 617f129e43
commit d9bcb92f75

View file

@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
kqsum[j] = kqsum_shared[j][threadIdx.x]; kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j] = warp_reduce_sum(kqsum[j]); kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
half dst_val = (__low2half(VKQ[j]) + __high2half(VKQ[j])); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
if (parallel_blocks == 1) { if (parallel_blocks == 1) {
dst_val /= kqsum[j]; dst_val /= kqsum[j_VKQ];
} }
dst[D*gridDim.y*(blockIdx.x*ncols + j) + D*blockIdx.y + tid] = dst_val; const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
} }
if (parallel_blocks == 1 || tid != 0) { if (parallel_blocks == 1 || tid != 0) {