cuda : add amd dpp version of warp_reduce_sum for half2
This commit is contained in:
parent
7a3f7e94ba
commit
9e6f2e2aff
1 changed files with 37 additions and 7 deletions
|
@ -321,6 +321,9 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
||||||
#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
|
#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
|
||||||
#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
|
#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
|
||||||
hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
|
hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
|
||||||
|
#define hip_move_dpph2(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
|
||||||
|
hip_move_dpph2_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
|
||||||
|
#define hip_ds_swizzleh2(src, pattern) hip_ds_swizzleh2_N<(pattern)>((src))
|
||||||
|
|
||||||
template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
|
template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
|
||||||
static __device__ __forceinline__ float hip_move_dppf_N(float x) {
|
static __device__ __forceinline__ float hip_move_dppf_N(float x) {
|
||||||
|
@ -334,6 +337,30 @@ static __device__ __forceinline__ float hip_move_dppf_N(float x) {
|
||||||
return tmp.val;
|
return tmp.val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
|
||||||
|
static __device__ __forceinline__ half2 hip_move_dpph2_N(half2 x) {
|
||||||
|
typedef union half2_b32 {
|
||||||
|
half2 val;
|
||||||
|
int b32;
|
||||||
|
} half2_b32_t;
|
||||||
|
half2_b32_t tmp;
|
||||||
|
tmp.val = x;
|
||||||
|
tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
|
||||||
|
return tmp.val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int pattern>
|
||||||
|
static __device__ __forceinline__ half2 hip_ds_swizzleh2_N(half2 src) {
|
||||||
|
typedef union half2_b32 {
|
||||||
|
half2 val;
|
||||||
|
int b32;
|
||||||
|
} half2_b32_t;
|
||||||
|
half2_b32_t tmp;
|
||||||
|
tmp.val = src;
|
||||||
|
tmp.b32 = __builtin_amdgcn_ds_swizzle(tmp.b32, pattern);
|
||||||
|
return tmp.val;
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
|
static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
|
||||||
x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
|
x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
|
||||||
x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
|
x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
|
||||||
|
@ -357,6 +384,15 @@ static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 x) {
|
||||||
|
x += hip_ds_swizzleh2(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
|
||||||
|
x += hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
|
||||||
|
x += hip_move_dpph2(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
|
||||||
|
x += hip_move_dpph2(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
|
||||||
|
x += hip_move_dpph2(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
|
static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
|
||||||
x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
|
x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
|
||||||
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
|
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
|
||||||
|
@ -428,13 +464,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
#if FP16_AVAILABLE
|
#if FP16_AVAILABLE
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#pragma unroll
|
return warp_reduce_sum_impl_amd(a);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
|
|
||||||
reinterpret_cast<half&>(a.x) += __low2half(a_other);
|
|
||||||
reinterpret_cast<half&>(a.y) += __high2half(a_other);
|
|
||||||
}
|
|
||||||
return a;
|
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue