fix cmake build
This commit is contained in:
parent
bb0d51accd
commit
c63dfdf765
2 changed files with 24 additions and 40 deletions
|
@ -271,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GGML_CUDA_F16
|
|
||||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -284,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
#endif // GGML_CUDA_F16
|
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -294,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
//#pragma unroll
|
#pragma unroll
|
||||||
// for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||||
// }
|
}
|
||||||
// return x;
|
return x;
|
||||||
//#else
|
#else
|
||||||
// GGML_UNUSED(x);
|
GGML_UNUSED(x);
|
||||||
// NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
//}
|
}
|
||||||
|
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
|
|
|
@ -3,32 +3,6 @@
|
||||||
|
|
||||||
#include <mma.h>
|
#include <mma.h>
|
||||||
|
|
||||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
|
||||||
}
|
|
||||||
return a;
|
|
||||||
#else
|
|
||||||
GGML_UNUSED(a);
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
|
||||||
}
|
|
||||||
|
|
||||||
// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
|
||||||
// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
|
||||||
// }
|
|
||||||
// return x;
|
|
||||||
// #else
|
|
||||||
// GGML_UNUSED(x);
|
|
||||||
// NO_DEVICE_CODE;
|
|
||||||
// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
|
||||||
// }
|
|
||||||
|
|
||||||
#define FATTN_KQ_STRIDE 256
|
#define FATTN_KQ_STRIDE 256
|
||||||
|
|
||||||
template<int D, int parallel_blocks> // D == head size
|
template<int D, int parallel_blocks> // D == head size
|
||||||
|
@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
|
||||||
|
@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
|
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
|
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
|
||||||
|
@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||||
|
@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
|
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int D, int parallel_blocks> // D == head size
|
template<int D, int parallel_blocks> // D == head size
|
||||||
|
@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results(
|
||||||
const float * __restrict__ VKQ_parts,
|
const float * __restrict__ VKQ_parts,
|
||||||
const half2 * __restrict__ VKQ_meta,
|
const half2 * __restrict__ VKQ_meta,
|
||||||
float * __restrict__ dst) {
|
float * __restrict__ dst) {
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
__builtin_assume(tid < D);
|
__builtin_assume(tid < D);
|
||||||
|
@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results(
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||||
|
#else
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int get_max_power_of_2(int x) {
|
constexpr int get_max_power_of_2(int x) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue