fix cmake build
This commit is contained in:
parent
bb0d51accd
commit
c63dfdf765
2 changed files with 24 additions and 40 deletions
|
@ -3,32 +3,6 @@
|
|||
|
||||
#include <mma.h>
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
||||
}
|
||||
return a;
|
||||
#else
|
||||
GGML_UNUSED(a);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
// #pragma unroll
|
||||
// for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||
// }
|
||||
// return x;
|
||||
// #else
|
||||
// GGML_UNUSED(x);
|
||||
// NO_DEVICE_CODE;
|
||||
// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
// }
|
||||
|
||||
#define FATTN_KQ_STRIDE 256
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
|
@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
|
||||
|
@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
|
||||
|
@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
|
@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
|
@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results(
|
|||
const float * __restrict__ VKQ_parts,
|
||||
const half2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results(
|
|||
}
|
||||
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue