CUDA: Enable FP16_MMA for RDNA3 with rocWMMA
This commit is contained in:
parent
70392f1f81
commit
8f1cd4ada6
5 changed files with 36 additions and 31 deletions
|
@ -177,7 +177,7 @@ set(CMAKE_C_STANDARD_REQUIRED true)
|
|||
if (GGML_SYCL)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
else()
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
|
||||
|
|
|
@ -131,6 +131,9 @@ typedef float2 dfloat2;
|
|||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
#define INT8_MMA_AVAILABLE
|
||||
|
@ -145,7 +148,7 @@ static constexpr bool fast_fp16_available(const int cc) {
|
|||
}
|
||||
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
||||
return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
|
||||
}
|
||||
|
||||
static constexpr bool int8_mma_available(const int cc) {
|
||||
|
@ -242,8 +245,6 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
||||
#if CUDART_VERSION >= CUDART_HMAX
|
||||
return __hmax2(a, b);
|
||||
#else
|
||||
|
@ -252,12 +253,6 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
|
|||
reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
|
||||
return ret;
|
||||
#endif // CUDART_VERSION >= CUDART_HMAX
|
||||
|
||||
#else
|
||||
GGML_UNUSED(a);
|
||||
GGML_UNUSED(b);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
|
|
|
@ -2,7 +2,13 @@
|
|||
#include "fattn-common.cuh"
|
||||
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
#include <rocwmma/rocwmma.hpp>
|
||||
namespace wmma = ::rocwmma;
|
||||
#else
|
||||
#include <mma.h>
|
||||
namespace wmma = ::nvcuda::wmma;
|
||||
#endif
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
|
@ -63,11 +69,11 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
|
@ -157,7 +163,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,20 +177,20 @@ static __global__ void flash_attn_ext_f16(
|
|||
frag_c_KQ KQ_c[ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
||||
wmma::fill_fragment(KQ_c[j], KQ_acc_t(0.0f));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
||||
wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -303,7 +309,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
nvcuda::wmma::load_matrix_sync(
|
||||
wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
|
@ -315,7 +321,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], half(0.0f));
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
@ -323,10 +329,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -338,10 +344,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, nvcuda::wmma::mem_col_major);
|
||||
D_padded, wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -73,6 +73,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
||||
constexpr int cols_per_block = 8;
|
||||
switch (Q->ne[0]) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
|
@ -85,6 +86,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
|
@ -305,7 +307,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
if (fp16_mma_available(cc) && (Q->ne[1] > 8 || Q->ne[0] % WARP_SIZE != 0)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
} else if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
#include "../fattn-wmma-f16.cuh"
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE(64, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(96, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 8, half);
|
||||
//DECL_FATTN_WMMA_F16_CASE(64, 8, half);
|
||||
//DECL_FATTN_WMMA_F16_CASE(96, 8, half);
|
||||
//DECL_FATTN_WMMA_F16_CASE(128, 8, half);
|
||||
//DECL_FATTN_WMMA_F16_CASE(256, 8, half);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue