cuda : use amd wave sharing intrinsics for warp_reduce functions

This commit is contained in:
Engininja2 2024-04-04 15:09:03 -06:00
parent fed0108491
commit 7a3f7e94ba

View file

@ -315,6 +315,57 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}
#ifdef __HIP_PLATFORM_AMD__
#define AMD_SWIZZLE_MASK(and_mask, or_mask, xor_mask) ((and_mask) | ((or_mask)<<5) | ((xor_mask)<<10)) // 5-bit masks applied sequentially to the thread id
#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
static __device__ __forceinline__ float hip_move_dppf_N(float x) {
typedef union float_b32 {
float val;
int b32;
} float_b32_t;
float_b32_t tmp;
tmp.val = x;
tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
return tmp.val;
}
static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
return x;
}
static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
a.x += __hip_ds_swizzlef(a.x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
a.y += __hip_ds_swizzlef(a.y, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
return a;
}
static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, false));
return x;
}
#endif // __HIP_PLATFORM_AMD__
#endif // defined(GGML_USE_HIPBLAS)
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
@ -349,20 +400,28 @@ static __device__ void no_device_code(
#endif // __CUDA_ARCH__
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_sum_impl_amd(x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_sum_impl_amd(a);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
@ -391,11 +450,15 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_max_impl_amd(x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {