fix performance regression
This commit is contained in:
parent
fa81c3a22c
commit
2272765196
1 changed files with 22 additions and 9 deletions
|
@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
half kqmax_new[ncols];
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
||||
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
|
||||
half kqmax_new = kqmax[0];
|
||||
half kqmax_new_arr[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
kqmax_new_arr[j] = kqmax[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
sum2[j] = warp_reduce_sum(sum2[j]);
|
||||
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
|
||||
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);
|
||||
|
||||
if (ncols == 1) {
|
||||
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
|
||||
} else {
|
||||
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
}
|
||||
|
@ -146,9 +157,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
|
||||
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
||||
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = kqmax_new[j];
|
||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,11 +169,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax_new[j] = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
|
||||
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
|
||||
kqmax[j] = kqmax_new[j];
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue