diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 455510684..8082e9464 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16( const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: - half kqmax_new[ncols]; + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + half kqmax_new = kqmax[0]; + half kqmax_new_arr[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { - kqmax_new[j] = kqmax[j]; + kqmax_new_arr[j] = kqmax[j]; } #pragma unroll @@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16( sum2[j] = warp_reduce_sum(sum2[j]); half sum = __low2half(sum2[j]) + __high2half(sum2[j]); sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum); + + if (ncols == 1) { + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); + } + if (threadIdx.x == 0) { KQ[j*D + i_KQ] = sum; } @@ -146,9 +157,11 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - kqmax_new[j] = warp_reduce_max(kqmax_new[j]); + half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new[j]; + kqmax_shared[j][threadIdx.y] = kqmax_new_j; } } @@ -156,11 +169,11 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - kqmax_new[j] = kqmax_shared[j][threadIdx.x]; - kqmax_new[j] = warp_reduce_max(kqmax_new[j]); + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]); - kqmax[j] = kqmax_new[j]; + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; const half val = hexp(KQ[j*D + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val;