Fix flashattn

This commit is contained in:
Jerome 2024-04-20 11:06:03 -04:00
parent 7eb14d5a6b
commit 3e560c8665
4 changed files with 11 additions and 17 deletions

View file

@ -1175,7 +1175,7 @@ add_library(ggml OBJECT
)
target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
target_compile_features (ggml PUBLIC c_std_11) # don't bump
target_compile_features (ggml PUBLIC cxx_std_17) # don't bump
target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})

View file

@ -364,16 +364,11 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#else
GGML_UNUSED(a);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
@ -399,8 +394,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

View file

@ -7,6 +7,12 @@
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
inline __device__ __half2 __hmax2(__half2 x, __half2 y) {
return __half2_raw{
{{__hmax(__half2_raw(x).x, __half2_raw(y).x),
__hmax(__half2_raw(x).y, __half2_raw(y).y)}}
};
}
#else
#include <mma.h>
namespace wmma = nvcuda::wmma;
@ -339,7 +345,7 @@ static __global__ void flash_attn_ext_f16(
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::fill_fragment(KQ_c[j], 0.0f);
wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@ -470,7 +476,7 @@ static __global__ void flash_attn_ext_f16(
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
}
#pragma unroll

View file

@ -15473,13 +15473,6 @@ struct llama_context * llama_new_context_with_model(
cparams.flash_attn = false;
}
#ifdef GGML_USE_HIPBLAS
if (cparams.flash_attn) {
LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
cparams.flash_attn = false;
}
#endif
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}