CUDA: fix softmax compile for old CUDA versions

(cherry picked from commit 09f771312392c9573fa01ce4ccfdc9edecce2eb9)
This commit is contained in:
JohannesGaessler 2024-01-10 18:27:00 +01:00 committed by Concedo
parent 941e70db14
commit ddc06843f1

View file

@ -116,6 +116,8 @@
#include "ggml.h"
#include "ggml-backend-impl.h"
#define CUDART_CI 11070 // CUDA 11.7, version used for the Github CI
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
@ -597,16 +599,16 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) a;
bad_arch();
#else
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#else
(void) a;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
@ -618,16 +620,16 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
}
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) x;
bad_arch();
#else
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#else
(void) x;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
}
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
@ -5416,7 +5418,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@ -5541,7 +5543,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
#else
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
}
template <bool vals_smem, int ncols_template, int block_size_template>
@ -8344,15 +8346,15 @@ static void ggml_cuda_op_soft_max(
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
const bool use_f16_soft_max = false;
#else
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_CI
#ifdef GGML_CUDA_F16
const bool use_f16_soft_max = true;
#else
const bool use_f16_soft_max = false;
#endif // GGML_CUDA_F16
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#else
const bool use_f16_soft_max = false;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_CI
if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);