From 8f1cd4ada66372cb631351b05736358834d07537 Mon Sep 17 00:00:00 2001
From: Ivan Chikish
Date: Sun, 22 Sep 2024 19:58:15 +0300
Subject: [PATCH] CUDA: Enable FP16_MMA for RDNA3 with rocWMMA

---
 ggml/CMakeLists.txt                        |  2 +-
 ggml/src/ggml-cuda/common.cuh              | 13 ++-----
 ggml/src/ggml-cuda/fattn-wmma-f16.cuh      | 38 +++++++++++--------
 ggml/src/ggml-cuda/fattn.cu                |  6 ++-
 .../fattn-wmma-f16-instance-kqhalf-cpb8.cu |  8 ++--
 5 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 89fdf9d1c..71bd04e8e 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -177,7 +177,7 @@ set(CMAKE_C_STANDARD_REQUIRED true)
 if (GGML_SYCL)
     set(CMAKE_CXX_STANDARD 17)
 else()
-    set(CMAKE_CXX_STANDARD 11)
+    set(CMAKE_CXX_STANDARD 17)
 endif()
 set(CMAKE_CXX_STANDARD_REQUIRED true)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 6a4bcdba0..86d09cfb3 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -131,6 +131,9 @@ typedef float2 dfloat2;
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)

 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 #define INT8_MMA_AVAILABLE
@@ -145,7 +148,7 @@ static constexpr bool fast_fp16_available(const int cc) {
 }

 static constexpr bool fp16_mma_available(const int cc) {
-    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+    return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
 }

 static constexpr bool int8_mma_available(const int cc) {
@@ -242,8 +245,6 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }

 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else
@@ -252,12 +253,6 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX
-
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }

 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index b10d19d93..0d0b7d620 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -2,7 +2,13 @@
 #include "fattn-common.cuh"

 #ifdef FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = ::rocwmma;
+#else
 #include <mma.h>
+namespace wmma = ::nvcuda::wmma;
+#endif
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -63,11 +69,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -157,7 +163,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -171,20 +177,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], KQ_acc_t(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }

@@ -303,7 +309,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -315,7 +321,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], half(0.0f));
             }

 #pragma unroll
@@ -323,10 +329,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -338,10 +344,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }

diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 83e5589a1..75810f644 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -73,6 +73,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
             case 64:
                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
@@ -85,6 +86,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
+#endif
             default:
                 GGML_ABORT("fatal error");
                 break;
@@ -305,7 +307,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (fp16_mma_available(cc) && (Q->ne[1] > 8 || Q->ne[0] % WARP_SIZE != 0)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        } else if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
index d30710c5f..ca7ecdf1a 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
@@ -2,7 +2,7 @@

 #include "../fattn-wmma-f16.cuh"

-DECL_FATTN_WMMA_F16_CASE(64, 8, half);
-DECL_FATTN_WMMA_F16_CASE(96, 8, half);
-DECL_FATTN_WMMA_F16_CASE(128, 8, half);
-DECL_FATTN_WMMA_F16_CASE(256, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(256, 8, half);