Fix flash-attn for AMD

Authored by Johannes Gäßler on 2024-04-19 13:50:17 -04:00, committed by Jerome
parent f364eb6fb5
commit 7eb14d5a6b
2 changed files with 99 additions and 91 deletions
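The gist of the two diffs below: in the CUDA common header, the NO_DEVICE_CODE macro and the warp_reduce_* helpers are moved so that they come after the HIP compatibility section instead of before it; in the flash-attention source, every hard-coded nvcuda::wmma:: qualifier is replaced by a wmma namespace alias that resolves to rocwmma on AMD, and the 32x8 tensor-core path is compiled out for AMD. Reassembled from the first hunk of the second file, the resulting include block reads roughly as follows (an illustrative sketch, not a verbatim copy of the file):

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;        // AMD: rocWMMA mirrors the fragment/load/mma/store API
#else
#include <mma.h>
namespace wmma = nvcuda::wmma;   // NVIDIA: tensor core API from mma.h
#endif

// Kernel code can then be written once for both vendors, e.g.
//     wmma::fill_fragment(acc, 0.0f);
//     wmma::mma_sync(acc, a, b, acc);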


@@ -232,80 +232,6 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
-[[noreturn]]
-static __device__ void no_device_code(
-    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
-           file_name, line, function_name, arch);
-    GGML_UNUSED(arch_list);
-#else
-    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
-           file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    __trap();
-
-    GGML_UNUSED(no_device_code); // suppress unused function warning
-}
-
-#ifdef __CUDA_ARCH__
-#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
-#else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
-#endif // __CUDA_ARCH__
-
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
 #if CUDART_VERSION < 12000
 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
     const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
@@ -397,6 +323,80 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
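
A note on the first file: the relocated block is textually unchanged, only its position moves. Its helpers rely on __shfl_xor_sync, __hadd2 and __hmax2, which in the HIP build are presumably only usable once the earlier HIP compatibility section of the header has mapped them to ROCm equivalents, so the definitions now follow that section. A minimal sketch of the ordering constraint, where the shim macro below is an assumption for illustration rather than a copy of the real header:

// 1) HIP compatibility shim first (assumed form):
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#define __shfl_xor_sync(mask, var, lane_mask, width) __shfl_xor(var, lane_mask, width)
#endif

// 2) Only afterwards, helpers that use the intrinsic, so the HIP compiler sees the shim first:
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // butterfly reduction across the 32-lane warp
    }
    return x;
}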


@@ -4,8 +4,14 @@
 #include <cstdint>

 #if FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#else
 #include <mma.h>
-#endif
+namespace wmma = nvcuda::wmma;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // FP16_MMA_AVAILABLE

 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
@@ -231,11 +237,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -319,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
@@ -333,20 +339,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], 0.0f);
         }

 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }
@@ -452,7 +458,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -464,7 +470,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
             }

 #pragma unroll
@@ -472,10 +478,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -487,10 +493,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
@@ -863,6 +869,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps = 4;
@@ -885,6 +892,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;