cuda : optimize argmax (#10441)

* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Author: Diego Devesa, 2024-11-21 18:18:50 +01:00 (committed by GitHub)
Parent: 1bb30bf28c
Commit: a5e47592b6
5 changed files with 110 additions and 67 deletions
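
A gloss on the message bullets: "fix ub" is short for undefined behavior, and the final bullet adds an overflow guard. argmax and argsort produce I32 index tensors, so a row longer than INT32_MAX cannot be represented in the output. A minimal sketch of such a guard, assuming it sits where the ops are constructed in ggml.c (the exact placement is an assumption, not taken from this diff):

    // sketch: bounds check when building the argmax/argsort op (placement assumed)
    GGML_ASSERT(a->ne[0] <= INT32_MAX);  // result indices are I32, so ne00 must fit in int32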

ggml/src/ggml-cuda/common.cuh

@@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
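
The rename from mask to offset sharpens the naming rather than the behavior: the first argument of __shfl_xor_sync is already a mask (the set of participating lanes, 0xffffffff here), while the loop variable is the XOR distance that pairs lane l with lane l ^ offset. Stepping through offsets 16, 8, 4, 2, 1 forms a butterfly, and after log2(32) = 5 steps every lane holds the sum of all 32 lanes. A host-side C++ simulation of the same schedule (illustration only, not part of the commit):

    // host-side simulation of the 32-lane XOR butterfly reduction
    #include <cstdio>

    int main() {
        int lane[32], other[32];
        for (int i = 0; i < 32; ++i) lane[i] = i;                     // each "lane" starts with its own value
        for (int offset = 16; offset > 0; offset >>= 1) {             // same schedule as the device loop
            for (int i = 0; i < 32; ++i) other[i] = lane[i ^ offset]; // stand-in for __shfl_xor_sync
            for (int i = 0; i < 32; ++i) lane[i] += other[i];
        }
        std::printf("%d\n", lane[7]);                                 // prints 496 = 0+1+...+31; every lane agrees
        return 0;
    }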
@@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
     }
     return a;
 }
@@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
         reinterpret_cast<half&>(a.x) += __low2half(a_other);
         reinterpret_cast<half&>(a.y) += __high2half(a_other);
     }
     return a;
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
     }
     return a;
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 }
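
warp_reduce_max is the primitive the optimized argmax kernel builds on: instead of a single float, argmax must reduce a (value, index) pair across the warp. A sketch of that pair reduction, assuming a helper along these lines (the commit's actual kernel in argmax.cu may differ in structure and tie-breaking):

    // sketch: warp-wide argmax over (value, index) pairs; illustrative, not the commit's exact code
    static __device__ __forceinline__ void warp_reduce_argmax(float & val, int & idx) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            const float val_other = __shfl_xor_sync(0xffffffff, val, offset, 32);
            const int   idx_other = __shfl_xor_sync(0xffffffff, idx, offset, 32);
            // keep the larger value; break ties toward the lower index
            if (val_other > val || (val_other == val && idx_other < idx)) {
                val = val_other;
                idx = idx_other;
            }
        }
    }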
@@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 #else