store temp KQ in registers
This commit is contained in:
parent
ef9e1593f3
commit
a5b0e2dea0
1 changed files with 16 additions and 10 deletions
|
@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
const int j = j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||||
|
}
|
||||||
|
|
||||||
half2 KQ_max_new = KQ_max[j0/nwarps];
|
half2 KQ_max_new = KQ_max[j0/nwarps];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
half2 val = KQ2[j*(kqs_padded/2) + k];
|
|
||||||
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||||
KQ_max_new = __hmax2(KQ_max_new, val);
|
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||||
KQ2[j*(kqs_padded/2) + k] = val;
|
|
||||||
}
|
}
|
||||||
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||||
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
|
const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
|
||||||
|
@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16(
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
half2 val = KQ2[j*(kqs_padded/2) + k];
|
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
|
||||||
const half2 diff = val - KQ_max[j0/nwarps];
|
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
||||||
val = h2exp(diff);
|
|
||||||
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
*((uint *) &val) &= ftz_mask;
|
*((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
||||||
KQ_rowsum_add += val;
|
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
||||||
KQ2[j*(kqs_padded/2) + k] = val;
|
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
||||||
}
|
}
|
||||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue