From a5b0e2dea018cfac5ee478aac0d780eef391b30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 15:58:21 +0200 Subject: [PATCH] store temp KQ in registers --- ggml-cuda/fattn.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f6289822e..b889cdb3b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, val); - KQ2[j*(kqs_padded/2) + k] = val; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; @@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - const half2 diff = val - KQ_max[j0/nwarps]; - val = h2exp(diff); + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &val) &= ftz_mask; - KQ_rowsum_add += val; - KQ2[j*(kqs_padded/2) + k] = val; + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);