From 5408d55506186b8e56f8a6a748688847ef0ebb7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 22 Apr 2024 19:12:06 +0300 Subject: [PATCH] cuda : uint -> uint32_t --- ggml-cuda/common.cuh | 6 +++--- ggml-cuda/fattn.cu | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index ac6de643d..e82d63e4a 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -307,9 +307,9 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { } #if CUDART_VERSION < 12000 -static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { - const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); - const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint32_t mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); return mask_low | mask_high; } #endif // CUDART_VERSION < 12000 diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4cf2907e8..2077da53d 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -418,8 +418,8 @@ static __global__ void flash_attn_ext_f16( KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; KQ_max_h2[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -429,8 +429,8 @@ static __global__ void flash_attn_ext_f16( const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } @@ -602,8 +602,8 @@ static __global__ void flash_attn_combine_results( for (int l = 0; l < parallel_blocks; ++l) { const float diff = meta[l].x - kqmax; const float KQ_max_scale = expf(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint *) &KQ_max_scale) &= ftz_mask; + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y;