cuda : uint -> uint32_t
This commit is contained in:
parent
f725ca90fb
commit
5408d55506
2 changed files with 9 additions and 9 deletions
|
@ -307,9 +307,9 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#if CUDART_VERSION < 12000
|
#if CUDART_VERSION < 12000
|
||||||
static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) {
|
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
|
||||||
const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b));
|
const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b));
|
||||||
const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b));
|
const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b));
|
||||||
return mask_low | mask_high;
|
return mask_low | mask_high;
|
||||||
}
|
}
|
||||||
#endif // CUDART_VERSION < 12000
|
#endif // CUDART_VERSION < 12000
|
||||||
|
|
|
@ -418,8 +418,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||||
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||||
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||||
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
*((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
||||||
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
||||||
|
|
||||||
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||||
|
@ -429,8 +429,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
||||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
||||||
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
*((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
||||||
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
||||||
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
||||||
}
|
}
|
||||||
|
@ -602,8 +602,8 @@ static __global__ void flash_attn_combine_results(
|
||||||
for (int l = 0; l < parallel_blocks; ++l) {
|
for (int l = 0; l < parallel_blocks; ++l) {
|
||||||
const float diff = meta[l].x - kqmax;
|
const float diff = meta[l].x - kqmax;
|
||||||
const float KQ_max_scale = expf(diff);
|
const float KQ_max_scale = expf(diff);
|
||||||
const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||||
*((uint *) &KQ_max_scale) &= ftz_mask;
|
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||||
|
|
||||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
||||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue