Fix flash-attn for AMD
parent f364eb6fb5
commit 7eb14d5a6b
2 changed files with 99 additions and 91 deletions
@@ -232,80 +232,6 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
-[[noreturn]]
-static __device__ void no_device_code(
-    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
-           file_name, line, function_name, arch);
-    GGML_UNUSED(arch_list);
-#else
-    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
-           file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    __trap();
-
-    GGML_UNUSED(no_device_code); // suppress unused function warning
-}
-
-#ifdef __CUDA_ARCH__
-#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
-#else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
-#endif // __CUDA_ARCH__
-
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
-
 #if CUDART_VERSION < 12000
 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
     const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
@@ -397,6 +323,80 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
 
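A note on the first file's hunks (not part of the diff itself): the [[noreturn]] no_device_code helper, the NO_DEVICE_CODE macro, and the warp_reduce_* functions are not modified, only moved from before the HIP compatibility block to after `#endif // defined(GGML_USE_HIPBLAS)`, presumably so that on AMD builds the warp-shuffle intrinsics they call are already mapped to their HIP equivalents by the time these functions are compiled. Below is a cut-down sketch of that ordering constraint; the shim mapping shown is an assumption for illustration, not the exact one in ggml-cuda.

// Illustration only: why the reductions must come after the compatibility shim.
// The __shfl_xor_sync -> __shfl_xor mapping here is an assumed example.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_runtime.h>
// HIP has no *_sync shuffle variants; forward to the plain wavefront shuffle.
#define __shfl_xor_sync(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#else
#include <cuda_runtime.h>
#endif

// Defined after the shim, so the same line compiles to __shfl_xor_sync on CUDA
// and to __shfl_xor on HIP; defining it above the shim would break the HIP build.
static __device__ __forceinline__ float warp_reduce_sum_demo(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}
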
@@ -4,8 +4,14 @@
 #include <cstdint>
 
 #if FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#else
 #include <mma.h>
-#endif
+namespace wmma = nvcuda::wmma;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // FP16_MMA_AVAILABLE
 
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
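Context for the remaining hunks (not part of the diff): because both backends are now reached through the single wmma alias, the kernel below can refer to fragments and the load/mma/store intrinsics without the nvcuda:: prefix. The sketch here shows that fragment workflow in isolation; the kernel name and tile shape are illustrative only (this is not the flash-attn kernel), and under nvcc it needs a tensor-core-capable arch such as sm_70 or newer.

// Sketch of the wmma fragment/load/mma/store pattern, written against the alias.
// Swapping the alias to rocwmma (as in the hunk above) is what lets the same
// source target AMD; everything below is standard WMMA usage, not ggml code.
#include <cuda_fp16.h>
#include <mma.h>
namespace wmma = nvcuda::wmma;

// One warp computes a 16x16 tile of C = A*B (A row-major, B col-major, fp32 accumulator).
__global__ void wmma_tile_demo(const half * A, const half * B, float * C,
                               int lda, int ldb, int ldc) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);       // clear the accumulator tile
    wmma::load_matrix_sync(a_frag, A, lda);  // warp-cooperative tile loads
    wmma::load_matrix_sync(b_frag, B, ldb);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, ldc, wmma::mem_row_major);
}
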
@@ -231,11 +237,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -319,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -333,20 +339,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], 0.0f);
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }
 
@@ -452,7 +458,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-            nvcuda::wmma::load_matrix_sync(
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);
@@ -464,7 +470,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
         }
 
 #pragma unroll
@@ -472,10 +478,10 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
             frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+            wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
             }
         }
     }
@@ -487,10 +493,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, nvcuda::wmma::mem_col_major);
+                D_padded, wmma::mem_col_major);
         }
     }
 
@@ -863,6 +869,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps = 4;
@@ -885,6 +892,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
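One last note (not part of the diff): the new `#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))` guard compiles the 8-column (32x8 fragment) branch out of the dispatcher on AMD, matching the comment that 32x8 tensor cores are not available there, so small batches fall through to the 16-column path instead. The host-only sketch below shows that effect in isolation; the helper name and the nwarps values for the wider tiles are placeholders, not taken from the commit, and the Q->ne[0] % WARP_SIZE check is omitted.

#include <cstdio>

// Placeholder stand-in for the real kernel-launch helpers in the dispatcher.
template <int cols_per_block, int nwarps>
static void launch_variant() {
    std::printf("launch: cols_per_block=%d nwarps=%d\n", cols_per_block, nwarps);
}

// n_cols_q plays the role of Q->ne[1].
static void dispatch_demo(int n_cols_q) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
    if (n_cols_q <= 8) {
        launch_variant<8, 4>();
        return;
    }
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    if (n_cols_q <= 32) {
        launch_variant<16, 4>(); // placeholder nwarps
        return;
    }
    launch_variant<32, 4>();     // placeholder nwarps
}

int main() {
    dispatch_demo(8);  // takes the 16-column path on AMD builds, the 8-column path otherwise
    dispatch_demo(64);
    return 0;
}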