From d9bcb92f759dcb53186c864b8b7b07d9ef1f6934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 4 May 2024 22:01:51 +0200 Subject: [PATCH] fix batch size 2-8 --- ggml-cuda/fattn.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 444ef5b74..ca760b0a1 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); #pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = kqsum_shared[j][threadIdx.x]; - kqsum[j] = warp_reduce_sum(kqsum[j]); + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - half dst_val = (__low2half(VKQ[j]) + __high2half(VKQ[j])); + half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); if (parallel_blocks == 1) { - dst_val /= kqsum[j]; + dst_val /= kqsum[j_VKQ]; } - dst[D*gridDim.y*(blockIdx.x*ncols + j) + D*blockIdx.y + tid] = dst_val; + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } if (parallel_blocks == 1 || tid != 0) {