From 3e560c8665d4ea627be920a26da6d83811fde3b4 Mon Sep 17 00:00:00 2001
From: Jerome
Date: Sat, 20 Apr 2024 11:06:03 -0400
Subject: [PATCH] Fix flashattn

---
 CMakeLists.txt       |  2 +-
 ggml-cuda/common.cuh |  9 ++-------
 ggml-cuda/fattn.cu   | 10 ++++++++--
 llama.cpp            |  7 -------
 4 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 477c5b57c..d00524407 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1175,7 +1175,7 @@ add_library(ggml OBJECT
             )
 
 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
-target_compile_features   (ggml PUBLIC c_std_11) # don't bump
+target_compile_features   (ggml PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
 
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index e742cd3c2..0069b9e52 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -364,16 +364,11 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
     }
     return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -399,8 +394,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 
 #define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA
 
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 1e5f4410b..a605ed566 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -7,6 +7,12 @@
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
+inline __device__ __half2 __hmax2(__half2 x, __half2 y) {
+    return __half2_raw{
+        {{__hmax(__half2_raw(x).x, __half2_raw(y).x),
+          __hmax(__half2_raw(x).y, __half2_raw(y).y)}}
+    };
+}
 #else
 #include <mma.h>
 namespace wmma = nvcuda::wmma;
@@ -339,7 +345,7 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@@ -470,7 +476,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
             }
 
 #pragma unroll
diff --git a/llama.cpp b/llama.cpp
index 18d6297ce..04c5d7706 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15473,13 +15473,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }
 
-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
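
Note (not part of the patch): the __hmax2 shim added to ggml-cuda/fattn.cu mirrors CUDA's elementwise half2 maximum, which the targeted ROCm toolchain apparently does not provide. Below is a minimal, hypothetical sketch for checking that semantics against the CUDA builtin; it assumes a CUDA toolkit where __hmax2 is usable in device code (roughly CUDA 11.7+ or sm_80), and the file and kernel names (hmax2_check.cu, hmax2_kernel) are illustrative only. Under ROCm, the same kernel would rely on the shim instead of the builtin.

// hmax2_check.cu -- hypothetical standalone sketch, not part of llama.cpp.
// Assumes the CUDA builtin __hmax2 is available for the target architecture.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void hmax2_kernel(const __half2 *a, const __half2 *b, __half2 *out) {
    // Elementwise maximum of two half2 values -- the semantics the HIP shim
    // in the patch is meant to reproduce.
    out[0] = __hmax2(a[0], b[0]);
}

int main() {
    const __half2 ha = __floats2half2_rn(1.0f, -2.0f);
    const __half2 hb = __floats2half2_rn(0.5f,  3.0f);

    __half2 *da, *db, *dout;
    cudaMalloc(&da,   sizeof(__half2));
    cudaMalloc(&db,   sizeof(__half2));
    cudaMalloc(&dout, sizeof(__half2));
    cudaMemcpy(da, &ha, sizeof(__half2), cudaMemcpyHostToDevice);
    cudaMemcpy(db, &hb, sizeof(__half2), cudaMemcpyHostToDevice);

    hmax2_kernel<<<1, 1>>>(da, db, dout);

    __half2 hout;
    cudaMemcpy(&hout, dout, sizeof(__half2), cudaMemcpyDeviceToHost);

    // Expected: max = (1.000000, 3.000000)
    printf("max = (%f, %f)\n", __low2float(hout), __high2float(hout));

    cudaFree(da);
    cudaFree(db);
    cudaFree(dout);
    return 0;
}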