CUDA: Enable FP16_MMA for RDNA3 with rocWMMA

Ivan Chikish 2024-09-22 19:58:15 +03:00
parent 70392f1f81
commit 8f1cd4ada6
5 changed files with 36 additions and 31 deletions

View file

@@ -177,7 +177,7 @@ set(CMAKE_C_STANDARD_REQUIRED true)
 
 if (GGML_SYCL)
     set(CMAKE_CXX_STANDARD 17)
 else()
-    set(CMAKE_CXX_STANDARD 11)
+    set(CMAKE_CXX_STANDARD 17)
 endif()
 set(CMAKE_CXX_STANDARD_REQUIRED true)
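
Note on the hunk above: rocWMMA is a header-only C++17 library, so the HIP build can no longer fall back to -std=c++11; bumping the default CXX standard to 17 keeps <rocwmma/rocwmma.hpp> compilable. A minimal sketch of a guard that makes the requirement explicit (illustrative only, not part of this commit):

    // Illustrative: fail fast if a TU that includes rocWMMA is built pre-C++17.
    static_assert(__cplusplus >= 201703L, "rocWMMA requires C++17");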

View file

@@ -131,6 +131,9 @@ typedef float2 dfloat2;
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 #define INT8_MMA_AVAILABLE
@@ -145,7 +148,7 @@ static constexpr bool fast_fp16_available(const int cc) {
 }
 
 static constexpr bool fp16_mma_available(const int cc) {
-    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+    return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
 }
 
 static constexpr bool int8_mma_available(const int cc) {
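
For context, a host-side sketch of how the runtime gate behaves, using the CC_* constants as defined in ggml-cuda/common.cuh at the time of this commit (CC_VOLTA = 700, CC_OFFSET_AMD = 1000000, CC_RDNA2/CC_RDNA3 = offset + 1030/1100); AMD capabilities live above CC_OFFSET_AMD, so the new clause enables only RDNA3 and newer:

    #include <cstdio>

    // Values mirrored from ggml-cuda/common.cuh as of this commit.
    #define CC_VOLTA      700
    #define CC_OFFSET_AMD 1000000
    #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
    #define CC_RDNA3      (CC_OFFSET_AMD + 1100)

    static constexpr bool fp16_mma_available(const int cc) {
        return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
    }

    int main() {
        printf("Turing: %d\n", fp16_mma_available(750));      // 1: NVIDIA tensor cores
        printf("RDNA2:  %d\n", fp16_mma_available(CC_RDNA2)); // 0: still vec kernels
        printf("RDNA3:  %d\n", fp16_mma_available(CC_RDNA3)); // 1: rocWMMA path
        return 0;
    }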
@@ -242,8 +245,6 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else
@@ -252,12 +253,6 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX
-
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
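
Removing the HIP-only branch of ggml_cuda_hmax2() appears safe because CUDART_VERSION is not defined under HIP, so the preprocessor comparison evaluates as 0 >= CUDART_HMAX and AMD builds take the #else fallback, which only uses __low2float/__high2float, __float2half and fmaxf, all available in HIP's fp16 headers. This drops the NO_DEVICE_CODE trap that previously made half2 max unusable in AMD device code.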

View file

@@ -2,7 +2,13 @@
 #include "fattn-common.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = ::rocwmma;
+#else
 #include <mma.h>
+namespace wmma = ::nvcuda::wmma;
+#endif
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
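
The namespace alias is what keeps the kernel body backend-neutral: rocWMMA intentionally mirrors the nvcuda::wmma API (fragment, fill_fragment, load_matrix_sync, mma_sync, store_matrix_sync). A self-contained sketch of the pattern, a single-warp 16x16x16 half GEMM tile (illustrative only, not code from this commit; assumes rocWMMA's __half support on the HIP side):

    #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    #include <hip/hip_fp16.h>
    #include <rocwmma/rocwmma.hpp>
    namespace wmma = ::rocwmma;
    #else
    #include <cuda_fp16.h>
    #include <mma.h>
    namespace wmma = ::nvcuda::wmma;
    #endif

    // One warp computes C = A*B for a 16x16 tile (leading dimension 16);
    // the body compiles unchanged on both backends via the wmma alias.
    static __global__ void tile_gemm_16(const half * A, const half * B, float * C) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;
        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16);
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
    }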
@@ -63,11 +69,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -157,7 +163,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -171,20 +177,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], KQ_acc_t(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }
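
The KQ_acc_t(0.0f) cast above (and the half(0.0f) cast in the VKQ loop below) is presumably required because rocwmma::fill_fragment expects a value of the fragment's exact element type, while nvcuda::wmma accepts anything implicitly convertible to it; on the CUDA side the explicit construction compiles to the same code as before.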
@@ -303,7 +309,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -315,7 +321,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], half(0.0f));
             }
 
 #pragma unroll
@@ -323,10 +329,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -338,10 +344,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }

View file

@@ -73,6 +73,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
             case 64:
                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
@@ -85,6 +86,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
+#endif
             default:
                 GGML_ABORT("fatal error");
                 break;
@@ -305,7 +307,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (fp16_mma_available(cc) && (Q->ne[1] > 8 || Q->ne[0] % WARP_SIZE != 0)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        } else if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
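
Reading of the reworked AMD branch: ggml_cuda_flash_attn_ext_wmma_f16() only takes its cols_per_block == 8 path when Q->ne[1] <= 8 and Q->ne[0] % WARP_SIZE == 0, and exactly those cases are compiled out on HIP in the hunk above, so the new condition routes every other shape to the rocWMMA kernel and leaves the remaining small-batch shapes on the vec kernels, as before.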

View file

@@ -2,7 +2,7 @@
 
 #include "../fattn-wmma-f16.cuh"
 
-DECL_FATTN_WMMA_F16_CASE(64, 8, half);
-DECL_FATTN_WMMA_F16_CASE(96, 8, half);
-DECL_FATTN_WMMA_F16_CASE(128, 8, half);
-DECL_FATTN_WMMA_F16_CASE(256, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(256, 8, half);
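
These cols_per_block == 8 instances correspond to the cases compiled out of the switch above: with ncols == 8 the kernel selects frag_m == 32, i.e. an asymmetric 32x8x16 fragment shape, which rocWMMA presumably does not provide (the remaining configurations all use 16x16x16). Note that commenting out the declarations also drops these instances from CUDA builds of this tree.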