fix performance regression
parent fa81c3a22c
commit 2272765196
1 changed file with 22 additions and 9 deletions
@@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
     const int k_start = parallel_blocks == 1 ? 0 : ip*D;
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
-        half kqmax_new[ncols];
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = kqmax[j];
+            kqmax_new_arr[j] = kqmax[j];
         }
 
 #pragma unroll
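The comment added above describes the trick: ncols is a compile-time constant (a template parameter of the kernel), so only one of the scalar kqmax_new and the array kqmax_new_arr is ever read, and the compiler can discard the unused definition. A minimal sketch of the same pattern, with hypothetical names (NCOLS, update_running_max) rather than the actual kernel code:

    #include <cuda_fp16.h>

    // Sketch only: NCOLS plays the role of ncols. Both a scalar and an array are
    // defined, but each template instantiation reads exactly one of them, so the
    // unused definition is dead code; this avoids keeping a half[1] array live.
    template <int NCOLS>
    __device__ void update_running_max(half * running_max, const half candidate) {
        half max_scalar = running_max[0]; // read only when NCOLS == 1
        half max_arr[NCOLS];              // read only when NCOLS > 1
    #pragma unroll
        for (int j = 0; j < NCOLS; ++j) {
            max_arr[j] = running_max[j];
        }

        if (NCOLS == 1) {
            if (__half2float(candidate) > __half2float(max_scalar)) {
                max_scalar = candidate;
            }
            running_max[0] = max_scalar;
        } else {
    #pragma unroll
            for (int j = 0; j < NCOLS; ++j) {
                if (__half2float(candidate) > __half2float(max_arr[j])) {
                    max_arr[j] = candidate;
                }
                running_max[j] = max_arr[j];
            }
        }
    }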
@@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
                 sum2[j] = warp_reduce_sum(sum2[j]);
                 half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
                 sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
-                kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new, sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
                 if (threadIdx.x == 0) {
                     KQ[j*D + i_KQ] = sum;
                 }
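warp_reduce_sum, warp_reduce_max and ggml_cuda_hmax are existing ggml helpers; this hunk only changes which variable receives the running maximum. For readers unfamiliar with the reduction idiom, a generic butterfly warp reduction of the kind such helpers typically use can be sketched as follows (an illustrative example, not ggml's actual implementation):

    #include <cuda_fp16.h>

    // Generic butterfly max-reduction over one warp of 32 threads.
    // After the loop every lane holds the warp-wide maximum.
    __device__ half warp_max_sketch(half v) {
        float x = __half2float(v);
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, 32));
        }
        return __float2half(x);
    }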
@@ -146,9 +157,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
-                kqmax_shared[j][threadIdx.y] = kqmax_new[j];
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
         }
 
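This hunk and the next implement a two-level maximum: each warp reduces its own kqmax_new_j, lane 0 writes the per-warp result to kqmax_shared[j][threadIdx.y], and afterwards every thread re-reads across warps via threadIdx.x and reduces again. A hedged sketch of that block-wide pattern, assuming a block of WARPS x 32 threads, a caller-provided __shared__ array, and the warp_max_sketch helper above:

    // Block-wide max built from two warp-level reductions (sketch only;
    // assumes blockDim.x == 32, blockDim.y == WARPS, and WARPS <= 32).
    template <int WARPS>
    __device__ half block_max_sketch(half v, half (&smem)[WARPS]) {
        v = warp_max_sketch(v);          // step 1: per-warp maximum
        if (threadIdx.x == 0) {
            smem[threadIdx.y] = v;       // one value per warp, smem is __shared__
        }
        __syncthreads();

        // step 2: every thread reads one per-warp result and reduces again,
        // so all threads end up with the block-wide maximum.
        half w = smem[threadIdx.x % WARPS];
        return warp_max_sketch(w);
    }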
@@ -156,11 +169,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            kqmax_new[j] = kqmax_shared[j][threadIdx.x];
-            kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
 
-            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
-            kqmax[j] = kqmax_new[j];
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
 
             const half val = hexp(KQ[j*D + tid] - kqmax[j]);
             kqsum[j] = kqsum[j]*KQ_max_scale + val;
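The unchanged tail of this hunk is the usual online-softmax update: when the running maximum grows from kqmax[j] to kqmax_new_j, the previously accumulated kqsum[j] is rescaled by exp(old_max - new_max) so that every term stays expressed relative to the current maximum. A small host-side C++ illustration of that invariant (hypothetical names, scalar floats instead of half):

    #include <cmath>
    #include <cstdio>

    // Online softmax accumulation: maintain (running_max, running_sum) such that
    // running_sum == sum_i exp(x_i - running_max) over the values seen so far.
    int main() {
        const float xs[] = {1.0f, 3.0f, 2.0f, 5.0f};
        float running_max = -INFINITY;
        float running_sum = 0.0f;
        for (float x : xs) {
            const float new_max = fmaxf(running_max, x);
            const float scale   = expf(running_max - new_max); // analogue of KQ_max_scale
            running_sum = running_sum*scale + expf(x - new_max);
            running_max = new_max;
        }
        // Equals sum_i exp(x_i - max_i x_i), computed in a single pass.
        printf("max = %f, sum = %f\n", running_max, running_sum);
        return 0;
    }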