cuda : try to fix __hgt2_mask

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-22 21:42:43 +03:00
parent c70bfd7bcb
commit c129369702
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -308,8 +308,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if CUDART_VERSION < 12000
// Fallback definition of __hgt2_mask, compiled only when CUDART_VERSION < 12000
// (presumably because the toolkit does not ship an intrinsic of this name
// before that version - TODO confirm against the CUDA release notes).
//
// Performs a lane-wise "greater than" on two half2 values and returns a 32-bit mask:
//   - bits  0..15 are all ones if low(a)  > low(b),  otherwise all zeros
//   - bits 16..31 are all ones if high(a) > high(b), otherwise all zeros
// A lane in which either operand is NaN compares false and therefore yields zeros.
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    // Compare in float rather than via operator> on half: the half comparison
    // operators / implicit conversions are not guaranteed to be usable in every
    // build configuration, while the explicit conversion intrinsics always are.
    const bool low_gt  = __low2float(a)  > __low2float(b);
    const bool high_gt = __high2float(a) > __high2float(b);

    const uint32_t mask_low  = low_gt  ? 0x0000FFFFu : 0u;
    const uint32_t mask_high = high_gt ? 0xFFFF0000u : 0u;
    return mask_low | mask_high;
}
#endif // CUDART_VERSION < 12000