cuda : uint -> uint32_t

This commit is contained in:
Georgi Gerganov 2024-04-22 19:12:06 +03:00
parent f725ca90fb
commit 5408d55506
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 9 additions and 9 deletions

View file

@@ -307,9 +307,9 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
}
#if CUDART_VERSION < 12000
static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) {
const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b));
const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b));
// Lane-wise "greater than" for two half2 values, returned as a 32-bit mask.
// Only compiled when CUDART_VERSION < 12000 (see the enclosing preprocessor guard).
//   bits  0..15 are all ones when __low2half(a)  > __low2half(b), otherwise all zero;
//   bits 16..31 are all ones when __high2half(a) > __high2half(b), otherwise all zero.
// NOTE(review): presumably newer toolkits provide an intrinsic with this name — confirm.
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
// Each comparison yields 0 or 1; multiplying by the lane pattern gives an empty or full lane.
const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b));
const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b));
// The two lane patterns do not overlap, so OR merges them into one mask.
return mask_low | mask_high;
}
#endif // CUDART_VERSION < 12000

View file

@@ -418,8 +418,8 @@ static __global__ void flash_attn_ext_f16(
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
@@ -429,8 +429,8 @@ static __global__ void flash_attn_ext_f16(
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
@@ -602,8 +602,8 @@ static __global__ void flash_attn_combine_results(
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
const float KQ_max_scale = expf(diff);
const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint *) &KQ_max_scale) &= ftz_mask;
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;