From c63dfdf765c48a0f78e162d0c02c7d69cbbc3083 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 2 Apr 2024 11:58:59 +0200
Subject: [PATCH] fix cmake build

---
 ggml-cuda/common.cuh | 26 ++++++++++++--------------
 ggml-cuda/fattn.cu   | 38 ++++++++++++--------------------------
 2 files changed, 24 insertions(+), 40 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index c245dd6ac..510ca6281 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -271,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-#ifdef GGML_CUDA_F16
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
@@ -284,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
     NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
-#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -294,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//    }
-//    return x;
-//#else
-//    GGML_UNUSED(x);
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//}
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
 
 
 #if defined(GGML_USE_HIPBLAS)
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index dcd2129a2..1d29346c7 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -3,32 +3,6 @@
 
 #include <mma.h>
 
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-// #pragma unroll
-//     for (int mask = 16; mask > 0; mask >>= 1) {
-//         x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//     }
-//     return x;
-// #else
-//     GGML_UNUSED(x);
-//     NO_DEVICE_CODE;
-// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-// }
-
 #define FATTN_KQ_STRIDE 256
 
 template<int D, int parallel_blocks> // D == head size
@@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
@@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16(
             dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
@@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
     static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16(
             __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 }
 
 template<int D, int parallel_blocks> // D == head size
@@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results(
         const float * __restrict__ VKQ_parts,
         const half2 * __restrict__ VKQ_meta,
         float * __restrict__ dst) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
@@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results(
     }
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+#else
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 constexpr int get_max_power_of_2(int x) {
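
Note (editor's sketch, not part of the patch): the helpers this patch moves into common.cuh use a warp-level XOR-shuffle ("butterfly") reduction. The minimal, self-contained CUDA program below illustrates that pattern for the half2 sum variant; the kernel and variable names are illustrative only, and it assumes a 32-lane warp and an architecture with half2 arithmetic (compile with, e.g., nvcc -arch=sm_60 warp_reduce_demo.cu).

#include <cuda_fp16.h>
#include <cstdio>

// Same shape as warp_reduce_sum(half2) in the patch: each XOR step exchanges
// values between lane pairs, so after log2(32) = 5 steps every lane holds the
// element-wise sum over the whole warp.
static __device__ __forceinline__ half2 warp_reduce_sum_demo(half2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
}

static __global__ void demo_kernel(float2 * out) {
    // Lane i contributes (1, i); the warp-wide sum is therefore (32, 0+1+...+31) = (32, 496).
    half2 v = __floats2half2_rn(1.0f, (float) threadIdx.x);
    v = warp_reduce_sum_demo(v);
    if (threadIdx.x == 0) {
        *out = __half22float2(v); // convert to float2 so the host can print it
    }
}

int main() {
    float2 * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float2));
    demo_kernel<<<1, 32>>>(d_out);
    float2 h_out;
    cudaMemcpy(&h_out, d_out, sizeof(float2), cudaMemcpyDeviceToHost);
    printf("warp sum = (%g, %g)\n", h_out.x, h_out.y); // expected: (32, 496)
    cudaFree(d_out);
    return 0;
}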