diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 963588c4f..31bfc43c1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_argsort(ctx, dst);
             break;
         case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
             break;
         default:
             return false;
     }
 
-    if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
-    } else {
-        func(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 0f135a184..bcf27fd79 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -1,5 +1,33 @@
 #include "fattn.cuh"
 
+#include <mma.h>
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
@@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
-        const char* __restrict__ q,
-        const char* __restrict__ k,
-        const char* __restrict__ v,
-        const char* __restrict__ mask,
-        float* __restrict__ dst,
+        const char * __restrict__ q,
+        const char * __restrict__ k,
+        const char * __restrict__ v,
+        const char * __restrict__ mask,
+        float * __restrict__ dst,
         float scale,
         int ne00,
         int ne01,
@@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16(
 #endif
 }
 
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh
index 1b764bc94..ad3ca7a8d 100644
--- a/ggml-cuda/fattn.cuh
+++ b/ggml-cuda/fattn.cuh
@@ -1,6 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_flash_attn_ext(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V,
-    const ggml_tensor * mask, ggml_tensor * KQV);
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
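Not part of the patch: the standalone sketch below illustrates the XOR-shuffle butterfly that the new warp_reduce_sum in fattn.cu relies on, stripped of the ggml feature guards. The names warp_reduce_sum_sketch and sum_one_warp, the test values, and the build flags are made up for illustration; it assumes a single warp of 32 threads and a GPU with native half2 arithmetic (e.g. nvcc -arch=sm_70 reduce_sketch.cu).

// reduce_sketch.cu -- illustrative only, not part of ggml
#include <cstdio>
#include <cuda_fp16.h>

// Same butterfly pattern as warp_reduce_sum above: each XOR shuffle pairs a lane
// with one at distance `mask`, so after log2(32) = 5 steps every lane holds the
// sum of all 32 half2 values in the warp.
static __device__ __forceinline__ half2 warp_reduce_sum_sketch(half2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
}

static __global__ void sum_one_warp(const half2 * x, float * out) {
    const half2 s = warp_reduce_sum_sketch(x[threadIdx.x]);
    if (threadIdx.x == 0) {
        // the .x and .y halves are reduced independently; combine them here
        *out = __low2float(s) + __high2float(s);
    }
}

int main() {
    half2 h_x[32];
    for (int i = 0; i < 32; ++i) {
        h_x[i] = __floats2half2_rn(1.0f, 2.0f); // every lane contributes (1, 2)
    }

    half2 * d_x   = nullptr;
    float * d_out = nullptr;
    float   h_out = 0.0f;

    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

    sum_one_warp<<<1, 32>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);

    printf("warp sum = %.1f (expected 96.0)\n", h_out); // 32*1 + 32*2

    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}

The same butterfly with __hmax2 in place of __hadd2 gives warp_reduce_max, which the flash-attention kernel uses to find the running row maximum for the numerically stable softmax.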