cuda : fix build
This commit is contained in:
parent
013721df2b
commit
6be02b5969
3 changed files with 44 additions and 16 deletions
|
@ -1,5 +1,33 @@
|
|||
#include "fattn.cuh"
|
||||
|
||||
#include <mma.h>
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
||||
}
|
||||
return a;
|
||||
#else
|
||||
GGML_UNUSED(a);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
GGML_UNUSED(x);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
||||
|
@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
|
|||
// based on metal version
|
||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char* __restrict__ q,
|
||||
const char* __restrict__ k,
|
||||
const char* __restrict__ v,
|
||||
const char* __restrict__ mask,
|
||||
float* __restrict__ dst,
|
||||
const char * __restrict__ q,
|
||||
const char * __restrict__ k,
|
||||
const char * __restrict__ v,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float scale,
|
||||
int ne00,
|
||||
int ne01,
|
||||
|
@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16(
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_flash_attn_ext(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V,
|
||||
const ggml_tensor * mask, ggml_tensor * KQV);
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue