Fix flashattn

parent 7eb14d5a6b
commit 3e560c8665

4 changed files with 11 additions and 17 deletions
@@ -1175,7 +1175,7 @@ add_library(ggml OBJECT
             )

 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
-target_compile_features (ggml PUBLIC c_std_11) # don't bump
+target_compile_features (ggml PUBLIC cxx_std_17) # don't bump

 target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})

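Note on the bump from c_std_11 to cxx_std_17: a plausible reason is that the rocwmma headers pulled in below require C++17, so the ggml objects have to be built with that standard. A tiny compile-time probe a rocwmma-including translation unit could carry, not part of the commit and purely illustrative:

    // Hypothetical guard, not in this tree: fail the build early if the
    // standard was not actually bumped for the CUDA/HIP objects.
    static_assert(__cplusplus >= 201703L,
                  "rocwmma-based flash attention needs C++17");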
@@ -364,16 +364,11 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }

 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
     }
     return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }

 static __device__ __forceinline__ float warp_reduce_max(float x) {
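For context, the body that is now compiled unconditionally is a butterfly reduction over one warp: five XOR-shuffle steps leave every lane holding the element-wise half2 sum of all 32 lanes. A minimal, self-contained check, not part of the commit (illustrative names; assumes a 32-thread launch and a device with native half2 arithmetic):

    #include <cuda_fp16.h>
    #include <cstdio>

    __global__ void check_warp_reduce_sum() {
        const int lane = threadIdx.x;                    // one warp: lanes 0..31
        half2 v = __floats2half2_rn((float) lane, 2.0f * lane);
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {      // same pattern as warp_reduce_sum(half2)
            v = __hadd2(v, __shfl_xor_sync(0xffffffff, v, mask, 32));
        }
        if (lane == 0) {                                 // expect 496 and 992: sums of 0..31 and 2*(0..31)
            printf("%g %g\n", __half2float(__low2half(v)), __half2float(__high2half(v)));
        }
    }

Launched as check_warp_reduce_sum<<<1, 32>>>(), lane 0 should print 496 992.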
@@ -399,8 +394,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {

 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA

 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
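The new definition changes FP16_MMA_AVAILABLE from "CUDA only, Volta or newer" to a ternary that also admits RDNA3 under HIP. A host-side restatement of the resulting truth table, purely illustrative (the helper name is made up; CC_VOLTA is 700 in common.cuh):

    constexpr bool fp16_mma_available(bool hip_amd, bool rdna3, int cuda_arch) {
        return hip_amd ? rdna3 : (cuda_arch >= 700 /* CC_VOLTA */);
    }
    static_assert( fp16_mma_available(true,  true,    0), "RDNA3 HIP build: WMMA path enabled");
    static_assert(!fp16_mma_available(true,  false,   0), "other AMD archs: still disabled");
    static_assert( fp16_mma_available(false, false, 800), "Ampere CUDA build: unchanged");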
@@ -7,6 +7,12 @@
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
+inline __device__ __half2 __hmax2(__half2 x, __half2 y) {
+    return __half2_raw{
+        {{__hmax(__half2_raw(x).x, __half2_raw(y).x),
+          __hmax(__half2_raw(x).y, __half2_raw(y).y)}}
+    };
+}
 #else
 #include <mma.h>
 namespace wmma = nvcuda::wmma;
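The added __hmax2 fills in an intrinsic that HIP's fp16 headers may not provide on this toolchain: an element-wise max over the two halves of a __half2. A float-based reference, handy for spot-checking the packed version and not part of the commit (names illustrative):

    __device__ __half2 hmax2_ref(__half2 x, __half2 y) {
        const float2 xf = __half22float2(x);            // unpack to float for clarity
        const float2 yf = __half22float2(y);
        return __floats2half2_rn(fmaxf(xf.x, yf.x),     // per-element max, repacked
                                 fmaxf(xf.y, yf.y));
    }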
@@ -339,7 +345,7 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@@ -470,7 +476,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
             }

 #pragma unroll
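Both fill_fragment changes replace a bare 0.0f with a zero spelled in the fragment's accumulator element type (KQ_acc_t and __half respectively). nvcuda::wmma tolerates the implicit float conversion, and passing the exact type is presumably what keeps the call acceptable to rocwmma as well. A standalone sketch of the pattern, assuming a Volta-or-newer CUDA target and not code from this tree:

    #include <mma.h>
    #include <cuda_fp16.h>
    using namespace nvcuda;

    __global__ void zero_accumulators() {
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_f;
        wmma::fill_fragment(acc_f, 0.0f);                // value already matches the element type

        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_h;
        wmma::fill_fragment(acc_h, __float2half(0.0f));  // spell the zero as half, no implicit narrowing
    }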
@@ -15473,13 +15473,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }

-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
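With the GGML_USE_HIPBLAS override removed, a caller that asks for flash attention keeps it on HIP builds instead of having it silently forced off. A minimal usage sketch against the llama.cpp API of this period (error handling omitted; assumes the model is already loaded):

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true;   // no longer reset to false on HIPBLAS builds by the block deleted above
    llama_context * ctx = llama_new_context_with_model(model, cparams);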