better toolchain compatibility

This commit is contained in:
Reinforce-II 2024-05-23 11:58:19 +08:00
parent 9a166331e0
commit c812542f86
4 changed files with 49 additions and 13 deletions

View file

@ -79,6 +79,7 @@ option(LLAMA_AVX512 "llama: enable AVX512"
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF) option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
# Opt-in switch for AMX support (default OFF). When enabled, later sections of
# this file either add the __AMX_TILE__/__AMX_INT8__/__AMX_BF16__ compile
# definitions or append the -mamx-tile/-mamx-int8/-mamx-bf16 flags (together
# with -mavx512vl/-mavx512dq) to ARCH_FLAGS.
option(LLAMA_AMX "llama: enable AMX" OFF)
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
# in MSVC F16C is implied with AVX2/AVX512 # in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC) if (NOT MSVC)
@ -1072,6 +1073,14 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>) add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>) add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif() endif()
if (LLAMA_AMX)
    # Define each AMX feature macro by hand for both C and C++ translation
    # units, in the order TILE, INT8, BF16.
    # NOTE(review): presumably needed because this branch's compiler does not
    # predefine these macros itself - confirm.
    foreach (amx_feature_macro __AMX_TILE__ __AMX_INT8__ __AMX_BF16__)
        add_compile_definitions($<$<COMPILE_LANGUAGE:C>:${amx_feature_macro}>)
        add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:${amx_feature_macro}>)
    endforeach()
endif()
elseif (LLAMA_AVX2) elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2) list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX) elseif (LLAMA_AVX)
@ -1106,6 +1115,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_BF16) if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16) list(APPEND ARCH_FLAGS -mavx512bf16)
endif() endif()
if (LLAMA_AMX)
    # AMX builds get the AVX512-VL/DQ flags first, then the three AMX
    # feature flags (tile, int8, bf16), appended in that order.
    list(APPEND ARCH_FLAGS
        -mavx512vl
        -mavx512dq
        -mamx-tile
        -mamx-int8
        -mamx-bf16
    )
endif()
endif() endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")

47
ggml.c
View file

@ -27,9 +27,7 @@
#if defined(__gnu_linux__) #if defined(__gnu_linux__)
#include <syscall.h> #include <syscall.h>
#if defined(__AMX_TILE__) && defined(__AMX_BF16__) #if defined(__AMX_TILE__) && defined(__AMX_BF16__)
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023 #define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18 #define XFEATURE_XTILEDATA 18
#endif #endif
#endif #endif
@ -1904,24 +1902,40 @@ static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void
} }
} }
} }
// Tile configuration block that ggml_vec_dot_bf16 fills in and passes to
// _tile_loadconfig(). The field sizes add up to 64 bytes
// (1 + 1 + 14 + 16 + 16 + 8 + 8).
// NOTE(review): field order and sizes presumably mirror the hardware tile
// configuration memory format - confirm against the Intel documentation
// before changing anything here.
typedef struct __tile_config
{
// Palette selector; the caller in this file sets it to 1.
uint8_t palette_id;
// The caller in this file sets this to 0.
uint8_t start_row;
// Reserved bytes; not set explicitly by the caller's initializer.
uint8_t reserved_0[14];
// One entry per tile index; the caller fills entries 0..2 with byte widths
// (e.g. AMX_TILE_MN*sizeof(float)).
uint16_t colsb[8];
// Reserved bytes; not set explicitly by the caller's initializer.
uint8_t reserved_1[16];
// One entry per tile index; the caller fills entries 0..2 with row counts.
uint8_t rows[8];
// Reserved bytes; not set explicitly by the caller's initializer.
uint8_t reserved_2[8];
} __tile_config_t;
#endif #endif
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__) #if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (nrc == AMX_TILE_MN) { if (nrc == AMX_TILE_MN) {
assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0); assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0);
__tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4}; // 0: zt, 1: yt, 2: xt
__tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4}; __tile_config_t cfg = {
__tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)}; .palette_id = 1,
__tile_zero(&tilezt); .start_row = 0,
.colsb = {AMX_TILE_MN*sizeof(float), AMX_TILE_K*4, AMX_TILE_MN*4, 0,},
.rows = {AMX_TILE_MN, AMX_TILE_K, AMX_TILE_MN, 0,},
};
_tile_loadconfig(&cfg);
_tile_zero(0);
for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) { for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) {
ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)]; ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)];
ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K); ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K);
__tile_loadd(&tileyt, y + i, by); _tile_loadd(1, y + i, by);
__tile_loadd(&tilext, axt, AMX_TILE_MN*4); _tile_loadd(2, axt, AMX_TILE_MN*4);
__tile_dpbf16ps(&tilezt, tileyt, tilext); _tile_dpbf16ps(0, 1, 2);
} }
__tile_stored(s, bs*sizeof(float), tilezt); _tile_stored(0, s, bs*sizeof(float));
return; return;
} }
#endif #endif
@ -19485,10 +19499,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
set_numa_thread_affinity(state->ith); set_numa_thread_affinity(state->ith);
#if defined(__gnu_linux__) #if defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__gnu_linux__)
#if defined(__AMX_TILE__) && defined(__AMX_BF16__) // refer to https://www.kernel.org/doc/Documentation/x86/xstate.rst
syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
#endif
#endif #endif
int node_n = -1; int node_n = -1;
@ -23045,4 +23058,12 @@ int ggml_cpu_has_matmul_int8(void) {
#endif #endif
} }
// Reports whether this build was compiled with both __AMX_TILE__ and
// __AMX_BF16__ defined. This is a compile-time answer only: it returns 1 when
// both macros are present and 0 otherwise, without probing the running CPU.
int ggml_cpu_has_amx(void) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
    const int amx_compiled_in = 1;
#else
    const int amx_compiled_in = 0;
#endif
    return amx_compiled_in;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

1
ggml.h
View file

@ -2421,6 +2421,7 @@ extern "C" {
GGML_API int ggml_cpu_has_sycl (void); GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_vsx (void);
GGML_API int ggml_cpu_has_matmul_int8(void); GGML_API int ggml_cpu_has_matmul_int8(void);
GGML_API int ggml_cpu_has_amx (void);
// //
// Internal types and functions exposed for tests and benchmarks // Internal types and functions exposed for tests and benchmarks

View file

@ -18355,6 +18355,7 @@ const char * llama_print_system_info(void) {
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | "; s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
s += "AMX = " + std::to_string(ggml_cpu_has_amx()) + " | ";
#ifdef GGML_USE_LLAMAFILE #ifdef GGML_USE_LLAMAFILE
s += "LLAMAFILE = 1 | "; s += "LLAMAFILE = 1 | ";
#else #else