fix performance regression

This commit is contained in:
Johannes Gäßler 2024-05-09 12:22:31 +02:00
parent fa81c3a22c
commit 2272765196

View file

@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols];
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
half kqmax_new = kqmax[0];
half kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = kqmax[j];
kqmax_new_arr[j] = kqmax[j];
}
#pragma unroll
@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
} else {
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
}
if (threadIdx.x == 0) {
KQ[j*D + i_KQ] = sum;
}
@ -146,9 +157,11 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new[j];
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
}
@ -156,11 +169,11 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = kqmax_shared[j][threadIdx.x];
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
kqmax[j] = kqmax_new[j];
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale + val;