store temp KQ in registers

This commit is contained in:
Johannes Gäßler 2024-04-16 15:58:21 +02:00
parent ef9e1593f3
commit a5b0e2dea0

View file

@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
}
half2 KQ_max_new = KQ_max[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
half2 val = KQ2[j*(kqs_padded/2) + k];
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = __hmax2(KQ_max_new, val);
KQ2[j*(kqs_padded/2) + k] = val;
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
half2 val = KQ2[j*(kqs_padded/2) + k];
const half2 diff = val - KQ_max[j0/nwarps];
val = h2exp(diff);
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint *) &val) &= ftz_mask;
KQ_rowsum_add += val;
KQ2[j*(kqs_padded/2) + k] = val;
*((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);