From 6adca19c94a506295121df99ba5d496432787fd2 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Fri, 24 Jan 2025 10:17:04 +0100 Subject: [PATCH 1/8] ggml-cpu: Add CPU backend support for KleidiAI library --- common/common.cpp | 2 + ggml/CMakeLists.txt | 1 + ggml/include/ggml-backend.h | 2 +- ggml/include/ggml-cpu.h | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 88 +++++- ggml/src/ggml-cpu/ggml-cpu-traits.cpp | 4 +- ggml/src/ggml-cpu/ggml-cpu-traits.h | 2 +- ggml/src/ggml-cpu/ggml-cpu.c | 33 ++- ggml/src/ggml-cpu/ggml-cpu.cpp | 30 +- .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp | 267 ++++++++++++++++++ .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.h | 18 ++ .../ggml-kleidiai/kleidiai_kernels.cpp | 165 +++++++++++ .../ggml-cpu/ggml-kleidiai/kleidiai_kernels.h | 62 ++++ include/llama.h | 2 + src/llama-model-loader.cpp | 4 +- src/llama-model-loader.h | 5 +- src/llama-model.cpp | 7 +- src/llama-quant.cpp | 2 +- src/llama.cpp | 2 +- 19 files changed, 675 insertions(+), 22 deletions(-) create mode 100644 ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp create mode 100644 ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h create mode 100644 ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp create mode 100644 ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h diff --git a/common/common.cpp b/common/common.cpp index 6dea8e3d2..044f218b4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1099,6 +1099,8 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.kv_overrides = params.kv_overrides.data(); } + mparams.n_threads = params.cpuparams.n_threads; + return mparams; } diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 185079aa4..0e892261d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -101,6 +101,7 @@ endif() option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) +option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if 
applicable" OFF) option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index fc9571c82..ce66a4733 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -189,7 +189,7 @@ extern "C" { // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); // Get additional buffer types provided by the device (returns a NULL-terminated array) - typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device); + typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device, int n_threads); // Set the abort callback for the backend typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data); // Get a list of feature flags supported by the backend (returns a NULL-terminated array) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 3aa71badb..4bb10ec43 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -95,6 +95,7 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); GGML_BACKEND_API int ggml_cpu_has_sve (void); GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + GGML_BACKEND_API int ggml_cpu_has_sme (void); // other GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6b3641c42..38447b6bd 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -126,6 +126,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) check_arm_feature(dotprod "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") 
check_arm_feature(i8mm "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") check_arm_feature(sve "#include \nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") + check_arm_feature(sme "#include \n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }") list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}") else() @@ -150,7 +151,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (ARM_FEATURE_RESULT) message(WARNING "Failed to get ARM features") else() - foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC) + foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) if (NOT ${feature_pos} EQUAL -1) message(STATUS "ARM feature ${feature} enabled") @@ -316,6 +317,91 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64) endif() + if (GGML_CPU_KLEIDIAI) + message(STATUS "Using KleidiAI optimized kernels if applicable") + + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + # Fetch KleidiAI sources: + include(FetchContent) + set(KLEIDIAI_COMMIT_SHA "v1.2.0") + set(KLEIDIAI_DOWNLOAD_URL "https://gitlab.arm.com/kleidi/kleidiai/-/archive/${KLEIDIAI_COMMIT_SHA}/kleidiai-${KLEIDIAI_COMMIT_SHA}.tar.gz") + set(KLEIDIAI_ARCHIVE_MD5 "cebcb660079bf15626e7bdaecd18f49c") + + if (POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + endif() + + FetchContent_Declare(KleidiAI_Download + URL ${KLEIDIAI_DOWNLOAD_URL} + DOWNLOAD_EXTRACT_TIMESTAMP NEW + URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + + FetchContent_MakeAvailable(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED) + + if (NOT KLEIDIAI_POPULATED) + message(FATAL_ERROR "KleidiAI source downloaded failed.") + endif() + + 
add_compile_definitions(GGML_USE_CPU_KLEIDIAI) + + # Remove kleidiai target after fetching it + if (TARGET kleidiai) + set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) + endif() + + list(APPEND GGML_CPU_SOURCES + ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp + ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp + ggml-cpu/ggml-kleidiai/ggml-kleidiai.h + ggml-cpu/ggml-kleidiai/kleidiai_kernels.h + ) + + # KleidiAI + include_directories( + ${KLEIDIAI_SRC}/ + ${KLEIDIAI_SRC}/kai/ + ${KLEIDIAI_SRC}/kai/ukernels/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) + + string(FIND ${ARCH_FLAGS} "+dotprod" DOTPROD_ENABLED) + string(FIND ${ARCH_FLAGS} "+i8mm" I8MM_ENABLED) + string(FIND ${ARCH_FLAGS} "+sme" SME_ENABLED) + + set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) + + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + + if (NOT DOTPROD_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) + endif() + + if (NOT I8MM_ENABLED MATCHES -1) + list(APPEND 
GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c) + endif() + + if (NOT SME_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) + set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") + endif() + + list(APPEND GGML_CDEF_PUBLIC GGML_USE_CPU_KLEIDIAI) + set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS ${PRIVATE_ARCH_FLAGS}) + list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES}) + endif() + message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}") target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES}) target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp index 62a0712da..14536fe1b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp @@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {} } // namespace ggml::cpu bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type(params->nth)) { if (extra && extra->context) { auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); @@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct } bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) { - for (auto extra : 
ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type(n_threads)) { if (extra && extra->context) { auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.h b/ggml/src/ggml-cpu/ggml-cpu-traits.h index 99a6186b1..eba2d379b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-traits.h +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.h @@ -33,6 +33,6 @@ class extra_buffer_type { } // namespace ggml::cpu // implemented in ggml-cpu.cpp. -std::vector & ggml_backend_cpu_get_extra_buffers_type(); +std::vector & ggml_backend_cpu_get_extra_buffers_type(int n_threads); #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0ed92b3ff..0cf95562c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -114,7 +114,8 @@ struct ggml_arm_arch_features_type { int has_i8mm; int has_sve; int sve_cnt; -} ggml_arm_arch_features = {-1, -1, -1, -1, 0}; + int has_sme; +} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1}; #endif @@ -2389,15 +2390,20 @@ bool ggml_is_numa(void) { #define HWCAP2_I8MM (1 << 13) #endif +#if !defined(HWCAP2_SME) +#define HWCAP2_SME (1 << 23) +#endif + static void ggml_init_arm_arch_features(void) { #if defined(__linux__) && defined(__aarch64__) uint32_t hwcap = getauxval(AT_HWCAP); uint32_t hwcap2 = getauxval(AT_HWCAP2); - ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); - ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); - ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME); #if defined(__ARM_FEATURE_SVE) ggml_arm_arch_features.sve_cnt = 
PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); @@ -2420,6 +2426,11 @@ static void ggml_init_arm_arch_features(void) { } ggml_arm_arch_features.has_i8mm = oldp; + if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_sme = oldp; + ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #else @@ -2443,6 +2454,12 @@ static void ggml_init_arm_arch_features(void) { ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #endif + +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2) + ggml_arm_arch_features.has_sme = 1; +#else + ggml_arm_arch_features.has_sme = 0; +#endif #endif } #endif @@ -14349,6 +14366,14 @@ int ggml_cpu_get_sve_cnt(void) { #endif } +int ggml_cpu_has_sme(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME) + return ggml_arm_arch_features.has_sme; +#else + return 0; +#endif +} + void ggml_cpu_init(void) { // needed to initialize f16 tables { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 35a1c876c..399f3f0f3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -14,6 +14,10 @@ #include "ggml-cpu-hbm.h" #endif +#ifdef GGML_USE_CPU_KLEIDIAI +#include "ggml-kleidiai/ggml-kleidiai.h" +#endif + #if defined(__APPLE__) #include #include @@ -29,8 +33,8 @@ // ggml-backend interface -std::vector& ggml_backend_cpu_get_extra_buffers_type() { - static std::vector bufts = []() { +std::vector& ggml_backend_cpu_get_extra_buffers_type(int n_threads) { + static std::vector bufts = [n_threads]() { std::vector bufts; #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) @@ -39,6 +43,12 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } #endif +#ifdef GGML_USE_CPU_KLEIDIAI + if (ggml_backend_cpu_kleidiai_buffer_type(n_threads)) { + bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type(n_threads)); + } +#endif + #ifdef GGML_USE_CPU_AARCH64 if (ggml_backend_cpu_aarch64_buffer_type()) { 
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); @@ -48,19 +58,21 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type bufts.push_back(NULL); return bufts; + + GGML_UNUSED(n_threads); }(); return bufts; } -static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) { - return ggml_backend_cpu_get_extra_buffers_type().data(); +static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device, int n_threads) { + return ggml_backend_cpu_get_extra_buffers_type(n_threads).data(); GGML_UNUSED(device); } static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type(-1)) { if (extra && extra == buft) return true; } return false; @@ -375,7 +387,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st } // extra_buffer_op? 
- for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type(-1)) { if (extra) { auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context; if (buf_extra && buf_extra->supports_op(dev, op)) { @@ -540,6 +552,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); } + if (ggml_cpu_has_sme()) { + features.push_back({ "SME", "1" }); + } if (ggml_cpu_has_riscv_v()) { features.push_back({ "RISCV_V", "1" }); } @@ -561,6 +576,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r #ifdef GGML_USE_OPENMP features.push_back({ "OPENMP", "1" }); #endif + #ifdef GGML_USE_CPU_KLEIDIAI + features.push_back({ "KLEIDIAI_REPACK", "1" }); + #endif #ifdef GGML_USE_CPU_AARCH64 features.push_back({ "AARCH64_REPACK", "1" }); #endif diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp new file mode 100644 index 000000000..504996146 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp @@ -0,0 +1,267 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#include +#elif defined(__APPLE__) +#include +#include +#include +#elif defined(_WIN32) +#include +#include +#endif + +#include "ggml-kleidiai.h" + +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "kleidiai_kernels.h" + +#include "kai_common.h" + +static const size_t k_q4_0_block_size = 32; + +struct ggml_kleidiai_context { + ggml_kleidiai_kernels * kernels; +} static ctx = { NULL }; + +static void init_kleidiai_context(int n_threads) { + static bool initialized = false; + + if (!initialized) { + GGML_ASSERT(n_threads > 0); + + 
initialized = true; + + cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + +#if defined(__APPLE__) + if (n_threads == 1) { + features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + } +#else + features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; +#endif + ctx.kernels = ggml_kleidiai_select_kernels(features); + } +} + +namespace ggml::cpu::kleidiai { +class tensor_traits : public ggml::cpu::tensor_traits { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + + size_t k = op->src[0]->ne[0]; + size_t m = op->src[1]->ne[1]; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + size_t bl = k_q4_0_block_size; + + size = ctx.kernels->lhs_info.packed_size(m, k, bl, mr, kr, sr); + + return true; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { + if (dst->op == GGML_OP_MUL_MAT) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = src1->ne[1] == 1 ? 
&ctx.kernels->gemv : &ctx.kernels->gemm; + lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + size_t bl = k_q4_0_block_size; + + const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, bl, mr, kr, sr); + + if (ith == 0) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); + void * dst_ptr = static_cast(lhs_packed + lhs_packed_offset); + + lhs_info->pack_func(m, k, bl, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr); + } + + ggml_barrier(params->threadpool); + // Perform the operation + const size_t dst_stride = dst->nb[1]; + + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, k_q4_0_block_size); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + + const void * lhs_ptr = static_cast(lhs_packed + lhs_packed_offset); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + kernel->run_kernel(m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr, + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + return true; + } + return false; + } + +public: + int repack(struct ggml_tensor * tensor, const void * 
data, size_t data_size) { + GGML_ASSERT(ctx.kernels); + const size_t n = tensor->ne[1]; + const size_t k = tensor->ne[0]; + size_t nr = ctx.kernels->gemm.get_nr(); + size_t kr = ctx.kernels->gemm.get_kr(); + size_t sr = ctx.kernels->gemm.get_sr(); + +#ifndef NDEBUG + const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, k_q4_0_block_size); + GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); +#endif + struct kai_rhs_pack_qs4cxs1s0_param params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, k_q4_0_block_size, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + + return 0; + } +}; + +static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { + static tensor_traits traits; + return &traits; +} +} // namespace ggml::cpu::kleidiai + +static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra; + auto OK = tensor_traits->repack(tensor, data, size); + + GGML_ASSERT(OK == 0); + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_KLEIDIAI"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return 
nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::kleidiai { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + if ( op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type(-1) && ctx.kernels + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type(-1)) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::kleidiai + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(int n_threads) { + static ggml::cpu::kleidiai::extra_buffer_type ctx; + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ nullptr, + 
}, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ &ctx, + }; + + init_kleidiai_context(n_threads); + + return &ggml_backend_cpu_buffer_type_kleidiai; +} diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h new file mode 100644 index 000000000..166c3f1a1 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml-cpu-traits.h" +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(int n_threads); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp new file mode 100644 index 000000000..97ac5ddb3 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp @@ -0,0 +1,165 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +// KleidiAI micro-kernels +#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_common.h" + +#include "kleidiai_kernels.h" + 
+#define NELEMS(x) sizeof(x) / sizeof(*x) +static ggml_kleidiai_kernels gemm_gemv_kernels[] = { +#if defined(__ARM_FEATURE_SME) + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, 
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + }, +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ 
kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ 
kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size 
= */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { + ggml_kleidiai_kernels * kernels = nullptr; + + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { + kernels = &gemm_gemv_kernels[i]; + break; + } + } + + return kernels; +} diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h new file mode 100644 index 000000000..0f97b46e9 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml.h" + +enum cpu_feature { + CPU_FEATURE_NONE = 0, + CPU_FEATURE_DOTPROD = 1, + CPU_FEATURE_I8MM = 2, + CPU_FEATURE_SVE = 4, + CPU_FEATURE_SME = 8 +}; +inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { + lhs = static_cast<cpu_feature>(lhs | rhs); + return lhs; +} +inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) { + return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs)); +} + +struct kernel_info { + size_t (*get_m_step)(void); + size_t (*get_n_step)(void); + size_t
(*get_mr)(void); + size_t (*get_nr)(void); + size_t (*get_kr)(void); + size_t (*get_sr)(void); + size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); + size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); + size_t (*get_dst_size)(size_t m, size_t n); + void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, + float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); +}; + +struct lhs_packing_info { + size_t (*get_offset)(size_t m_idx, size_t lhs_stride); + size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed); +}; + +struct rhs_packing_info { + size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); +}; + +struct ggml_kleidiai_kernels { + kernel_info gemm; + kernel_info gemv; + lhs_packing_info lhs_info; + rhs_packing_info rhs_info; + + cpu_feature required_cpu; +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); diff --git a/include/llama.h b/include/llama.h index 3b75e7607..bb3aa8674 100644 --- a/include/llama.h +++ b/include/llama.h @@ -304,6 +304,8 @@ extern "C" { bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data + + int n_threads; }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause 
crashes or incorrect results in certain configurations diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 75073bf61..512faee18 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader( std::vector & splits, bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p) { + const struct llama_model_kv_override * param_overrides_p, + int n_threads) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -683,6 +684,7 @@ llama_model_loader::llama_model_loader( this->use_mmap = use_mmap; this->check_tensors = check_tensors; + this->n_threads = n_threads; } std::string llama_model_loader::get_arch_name() const { diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index fe35404b2..49cb18a3d 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -77,6 +77,8 @@ struct llama_model_loader { llama_mmaps mappings; + int n_threads; + std::map weights_map; std::unordered_map kv_overrides; @@ -95,7 +97,8 @@ struct llama_model_loader { std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p); + const struct llama_model_kv_override * param_overrides_p, + int n_threads); template typename std::enable_if::value, bool>::type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 031b4c30b..199ecdcab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -247,7 +247,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> CPU extra -> GPU host -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, int n_threads) { buft_list_t buft_list; // add ACCEL buffer types @@ -268,7 +268,7 @@ static buft_list_t make_cpu_buft_list(const 
std::vector & de auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev, n_threads); while (extra_bufts && *extra_bufts) { buft_list.emplace_back(cpu_dev, *extra_bufts); ++extra_bufts; @@ -1264,7 +1264,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const bool use_mmap_buffer = true; // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.n_threads); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3768,6 +3768,7 @@ struct llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, }; #ifdef GGML_USE_METAL diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index fb7982655..0ebb7504f 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nthread); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/src/llama.cpp b/src/llama.cpp index e8cfe5012..179460a4f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader 
ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.n_threads); ml.print_info(); From 119d3bf98663d54598e236f7cfa2903f5b2c57ec Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Thu, 30 Jan 2025 12:50:08 +0100 Subject: [PATCH 2/8] Add environmental variable GGML_KLEIDIAI_SME --- common/common.cpp | 2 -- ggml/include/ggml-backend.h | 2 +- ggml/src/ggml-cpu/CMakeLists.txt | 6 ++--- ggml/src/ggml-cpu/ggml-cpu-traits.cpp | 4 +-- ggml/src/ggml-cpu/ggml-cpu-traits.h | 2 +- ggml/src/ggml-cpu/ggml-cpu.cpp | 20 +++++++------- .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp | 26 ++++++++++--------- .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.h | 2 +- include/llama.h | 2 -- src/llama-model-loader.cpp | 4 +-- src/llama-model-loader.h | 5 +--- src/llama-model.cpp | 7 +++-- src/llama-quant.cpp | 2 +- src/llama.cpp | 2 +- 14 files changed, 38 insertions(+), 48 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 044f218b4..6dea8e3d2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1099,8 +1099,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.kv_overrides = params.kv_overrides.data(); } - mparams.n_threads = params.cpuparams.n_threads; - return mparams; } diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index ce66a4733..fc9571c82 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -189,7 +189,7 @@ extern "C" { // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); // Get additional buffer types provided by the device (returns a NULL-terminated array) - typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device, int n_threads); + typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t 
device); // Set the abort callback for the backend typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data); // Get a list of feature flags supported by the backend (returns a NULL-terminated array) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 38447b6bd..bba18303b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -325,9 +325,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_SHA "v1.2.0") - set(KLEIDIAI_DOWNLOAD_URL "https://gitlab.arm.com/kleidi/kleidiai/-/archive/${KLEIDIAI_COMMIT_SHA}/kleidiai-${KLEIDIAI_COMMIT_SHA}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "cebcb660079bf15626e7bdaecd18f49c") + set(KLEIDIAI_COMMIT_TAG "v1.2.0") + set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") + set(KLEIDIAI_ARCHIVE_MD5 "6634fefce7357ecfee9eace2068bc68b") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp index 14536fe1b..62a0712da 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp @@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {} } // namespace ggml::cpu bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type(params->nth)) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { if (extra && extra->context) { auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); @@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct } bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) { - for 
(auto extra : ggml_backend_cpu_get_extra_buffers_type(n_threads)) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { if (extra && extra->context) { auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.h b/ggml/src/ggml-cpu/ggml-cpu-traits.h index eba2d379b..99a6186b1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-traits.h +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.h @@ -33,6 +33,6 @@ class extra_buffer_type { } // namespace ggml::cpu // implemented in ggml-cpu.cpp. -std::vector & ggml_backend_cpu_get_extra_buffers_type(int n_threads); +std::vector & ggml_backend_cpu_get_extra_buffers_type(); #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 399f3f0f3..b79d979db 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -33,8 +33,8 @@ // ggml-backend interface -std::vector& ggml_backend_cpu_get_extra_buffers_type(int n_threads) { - static std::vector bufts = [n_threads]() { +std::vector& ggml_backend_cpu_get_extra_buffers_type() { + static std::vector bufts = []() { std::vector bufts; #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) @@ -44,8 +44,8 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type #endif #ifdef GGML_USE_CPU_KLEIDIAI - if (ggml_backend_cpu_kleidiai_buffer_type(n_threads)) { - bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type(n_threads)); + if (ggml_backend_cpu_kleidiai_buffer_type()) { + bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); } #endif @@ -58,21 +58,19 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type bufts.push_back(NULL); return bufts; - - GGML_UNUSED(n_threads); }(); return bufts; } -static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device, int n_threads) { - return ggml_backend_cpu_get_extra_buffers_type(n_threads).data(); +static ggml_backend_buffer_type_t * 
ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) { + return ggml_backend_cpu_get_extra_buffers_type().data(); GGML_UNUSED(device); } static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type(-1)) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { if (extra && extra == buft) return true; } return false; @@ -387,7 +385,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st } // extra_buffer_op? - for (auto extra : ggml_backend_cpu_get_extra_buffers_type(-1)) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { if (extra) { auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context; if (buf_extra && buf_extra->supports_op(dev, op)) { @@ -577,7 +575,7 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r features.push_back({ "OPENMP", "1" }); #endif #ifdef GGML_USE_CPU_KLEIDIAI - features.push_back({ "KLEIDIAI_REPACK", "1" }); + features.push_back({ "KLEIDIAI", "1" }); #endif #ifdef GGML_USE_CPU_AARCH64 features.push_back({ "AARCH64_REPACK", "1" }); diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp index 504996146..32eadbf49 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp @@ -34,25 +34,25 @@ struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels; } static ctx = { NULL }; -static void init_kleidiai_context(int n_threads) { +static void init_kleidiai_context(void) { static bool initialized = false; if (!initialized) { - GGML_ASSERT(n_threads > 0); - initialized = true; + const char *env_var = getenv("GGML_KLEIDIAI_SME"); + int sme_enabled = 0; cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | (ggml_cpu_has_sve() ? 
CPU_FEATURE_SVE : CPU_FEATURE_NONE); -#if defined(__APPLE__) - if (n_threads == 1) { + if (env_var) { + sme_enabled = atoi(env_var); + } + + if (sme_enabled != 0) { features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; } -#else - features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; -#endif ctx.kernels = ggml_kleidiai_select_kernels(features); } } @@ -162,6 +162,8 @@ public: ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, k_q4_0_block_size, (const uint8_t *)data, NULL, tensor->data, 0, &params); return 0; + + GGML_UNUSED(data_size); } }; @@ -223,7 +225,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { op->src[0]->type == GGML_TYPE_Q4_0 && op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type(-1) && ctx.kernels + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels ) { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; @@ -237,7 +239,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { if (op->op == GGML_OP_MUL_MAT) { - if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type(-1)) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } } @@ -246,7 +248,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { }; } // namespace ggml::cpu::kleidiai -ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(int n_threads) { +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { static ggml::cpu::kleidiai::extra_buffer_type ctx; static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = { /* .iface = */ { @@ -261,7 +263,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(int n_threads) /*
.context = */ &ctx, }; - init_kleidiai_context(n_threads); + init_kleidiai_context(); return &ggml_backend_cpu_buffer_type_kleidiai; } diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h index 166c3f1a1..aca221e8e 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h @@ -11,7 +11,7 @@ extern "C" { #endif -ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(int n_threads); +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void); #ifdef __cplusplus } diff --git a/include/llama.h b/include/llama.h index bb3aa8674..3b75e7607 100644 --- a/include/llama.h +++ b/include/llama.h @@ -304,8 +304,6 @@ extern "C" { bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data - - int n_threads; }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 512faee18..75073bf61 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -445,8 +445,7 @@ llama_model_loader::llama_model_loader( std::vector & splits, bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p, - int n_threads) { + const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -684,7 +683,6 @@ llama_model_loader::llama_model_loader( this->use_mmap = use_mmap; this->check_tensors = check_tensors; - this->n_threads = n_threads; } std::string llama_model_loader::get_arch_name() const { diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 49cb18a3d..fe35404b2 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -77,8 +77,6 @@ struct llama_model_loader { 
llama_mmaps mappings; - int n_threads; - std::map weights_map; std::unordered_map kv_overrides; @@ -97,8 +95,7 @@ struct llama_model_loader { std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p, - int n_threads); + const struct llama_model_kv_override * param_overrides_p); template typename std::enable_if::value, bool>::type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 199ecdcab..031b4c30b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -247,7 +247,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> CPU extra -> GPU host -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, int n_threads) { +static buft_list_t make_cpu_buft_list(const std::vector & devices) { buft_list_t buft_list; // add ACCEL buffer types @@ -268,7 +268,7 @@ static buft_list_t make_cpu_buft_list(const std::vector & de auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev, n_threads); + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); while (extra_bufts && *extra_bufts) { buft_list.emplace_back(cpu_dev, *extra_bufts); ++extra_bufts; @@ -1264,7 +1264,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const bool use_mmap_buffer = true; // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.n_threads); + pimpl->cpu_buft_list = make_cpu_buft_list(devices); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3768,7 +3768,6 @@ struct 
llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, - /*.n_threads =*/ GGML_DEFAULT_N_THREADS, }; #ifdef GGML_USE_METAL diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 0ebb7504f..fb7982655 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nthread); + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/src/llama.cpp b/src/llama.cpp index 179460a4f..e8cfe5012 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.n_threads); + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); ml.print_info(); From f4eb1b38546bca8385c243fce1b93515d5ca707a Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Tue, 4 Feb 2025 13:38:27 +0100 Subject: [PATCH 3/8] Add support for multithread LHS conversion --- ggml/src/ggml-cpu/CMakeLists.txt | 12 +++---- .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp | 33 +++++++++++-------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index bba18303b..4b0ca0daa 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -117,7 +117,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ) if (GGML_MACHINE_SUPPORTS_${tag}) set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) - else() + elseif(NOT tag STREQUAL "sme") 
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) endif() set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) @@ -325,9 +325,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.2.0") + set(KLEIDIAI_COMMIT_TAG "v1.3.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "6634fefce7357ecfee9eace2068bc68b") + set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -370,9 +370,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) - string(FIND ${ARCH_FLAGS} "+dotprod" DOTPROD_ENABLED) - string(FIND ${ARCH_FLAGS} "+i8mm" I8MM_ENABLED) - string(FIND ${ARCH_FLAGS} "+sme" SME_ENABLED) + string(FIND "${ARCH_FLAGS}" "+dotprod" DOTPROD_ENABLED) + string(FIND "${ARCH_FLAGS}" "+i8mm" I8MM_ENABLED) + string(FIND "${ARCH_FLAGS}" "+sme" SME_ENABLED) set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp index 32eadbf49..77fe8e86b 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp @@ -114,30 +114,37 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = kernel->get_sr(); size_t bl = k_q4_0_block_size; - const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, bl, mr, kr, sr); + // Calculate number of columns to be processed per thread + const size_t num_m_per_thread = kai_roundup(m, nth) / nth; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } - if (ith == 0) { + if(m_start < m) { // Transform LHS - const size_t 
src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); - void * dst_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset); + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, bl, mr, kr, sr); + void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset); - lhs_info->pack_func(m, k, bl, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr); + lhs_info->pack_func(m_to_process, k, bl, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr); } ggml_barrier(params->threadpool); - // Perform the operation - const size_t dst_stride = dst->nb[1]; + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, k_q4_0_block_size, mr, kr, sr); const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, k_q4_0_block_size); const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - - const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset); - const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset); - float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); + const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); kernel->run_kernel(m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); return true; } return false; From 3e08f37b08b488f8d60f7e33b5f8aeb3c3c4f1c5 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Wed, 5 Feb 2025 15:20:07 +0100 Subject: [PATCH 4/8] Switch kernel selection order to dotprod and i8mm --- ggml/src/ggml-cpu/CMakeLists.txt | 4
+- .../ggml-kleidiai/kleidiai_kernels.cpp | 86 +++++++++---------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 4b0ca0daa..6d8fce504 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -398,8 +398,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() list(APPEND GGML_CDEF_PUBLIC GGML_USE_CPU_KLEIDIAI) - set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS ${PRIVATE_ARCH_FLAGS}) - list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES}) + set_source_files_properties("${GGML_KLEIDIAI_SOURCES}" PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") + list(APPEND GGML_CPU_SOURCES "${GGML_KLEIDIAI_SOURCES}") endif() message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}") diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp index 97ac5ddb3..fbb44cf17 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp @@ -63,49 +63,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .required_cpu = */ CPU_FEATURE_SME, }, #endif -#if defined(__ARM_FEATURE_MATMUL_INT8) - { - /* i8mm GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, 
- /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - }, - /* i8mm GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = 
*/ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, - }, -#endif #if defined(__ARM_FEATURE_DOTPROD) { /* DOTPROD GEMM */ @@ -149,6 +106,49 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .required_cpu = */ CPU_FEATURE_DOTPROD, }, #endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ 
kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif }; ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { From 9edd10737bdb5df3f84edcf5923dc6f3f18a86e9 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Fri, 7 Feb 2025 16:09:58 +0100 Subject: [PATCH 5/8] updates for review comments --- ggml/src/ggml-cpu/CMakeLists.txt | 23 +++-- .../ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp | 12 ++- .../ggml-kleidiai/kleidiai_kernels.cpp | 89 +++++++++++++++++++ 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6d8fce504..10c7eb847 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) function(check_arm_feature tag code) 
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}") - check_cxx_source_runs( - "${code}" - GGML_MACHINE_SUPPORTS_${tag} - ) + check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag}) if (GGML_MACHINE_SUPPORTS_${tag}) set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) - elseif(NOT tag STREQUAL "sme") - set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) + else() + set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}") + check_cxx_source_compiles("${code}" GGML_MACHINE_SUPPORTS_no${tag}) + if (GGML_MACHINE_SUPPORTS_no${tag}) + set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) + endif() endif() set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) endfunction() @@ -370,9 +371,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) - string(FIND "${ARCH_FLAGS}" "+dotprod" DOTPROD_ENABLED) - string(FIND "${ARCH_FLAGS}" "+i8mm" I8MM_ENABLED) - string(FIND "${ARCH_FLAGS}" "+sme" SME_ENABLED) + set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") + if (NOT ARCH_FLAGS_TEMP) + string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}") + endif() + string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED) + string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) + string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp index 77fe8e86b..8b689880e 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp @@ -23,6 +23,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-threading.h" #include "kleidiai_kernels.h" @@ -35,6 +36,8 @@ struct ggml_kleidiai_context { } static ctx = { NULL }; static void 
init_kleidiai_context(void) { + + ggml_critical_section_start(); static bool initialized = false; if (!initialized) { @@ -55,6 +58,12 @@ static void init_kleidiai_context(void) { } ctx.kernels = ggml_kleidiai_select_kernels(features); } + ggml_critical_section_end(); +} + +static inline int ggml_ne(const ggml_tensor * tensor, int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + return tensor->ne[dim]; } namespace ggml::cpu::kleidiai { @@ -237,7 +246,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } - if (op->src[1]->type == GGML_TYPE_F32) { + if (op->src[1]->type == GGML_TYPE_F32 && + ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { return true; } } diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp index fbb44cf17..4d100c9a9 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp @@ -63,6 +63,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .required_cpu = */ CPU_FEATURE_SME, }, #endif +#if defined(__APPLE__) #if defined(__ARM_FEATURE_DOTPROD) { /* DOTPROD GEMM */ @@ -149,6 +150,94 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, }, #endif +#else +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ 
kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = 
*/ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ 
kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +#endif }; ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { From e04880fa54314b0448123d115da678b06adf51a1 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Fri, 7 Feb 2025 17:42:54 +0100 Subject: [PATCH 6/8] More updates for review comments --- ggml/src/ggml-cpu/CMakeLists.txt | 2 +- ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 10c7eb847..043b12f05 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -116,7 +116,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) else() set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}") - 
check_cxx_source_compiles("${code}" GGML_MACHINE_SUPPORTS_no${tag}) + check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag}) if (GGML_MACHINE_SUPPORTS_no${tag}) set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) endif() diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp index 8b689880e..abd2e9ffc 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp +++ b/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp @@ -61,7 +61,7 @@ static void init_kleidiai_context(void) { ggml_critical_section_end(); } -static inline int ggml_ne(const ggml_tensor * tensor, int dim) { +static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); return tensor->ne[dim]; } From ca0c8b6b73601c44a7d2d951251fee7e3254d965 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 10 Feb 2025 15:01:27 +0100 Subject: [PATCH 7/8] Reorganize and rename KleidiAI files --- ggml/src/ggml-cpu/CMakeLists.txt | 8 +++--- ggml/src/ggml-cpu/ggml-cpu.cpp | 2 +- .../kernels.cpp} | 2 +- .../kleidiai_kernels.h => kleidiai/kernels.h} | 4 --- .../kleidiai.cpp} | 25 +++++++++---------- .../ggml-kleidiai.h => kleidiai/kleidiai.h} | 1 - 6 files changed, 18 insertions(+), 24 deletions(-) rename ggml/src/ggml-cpu/{ggml-kleidiai/kleidiai_kernels.cpp => kleidiai/kernels.cpp} (99%) rename ggml/src/ggml-cpu/{ggml-kleidiai/kleidiai_kernels.h => kleidiai/kernels.h} (98%) rename ggml/src/ggml-cpu/{ggml-kleidiai/ggml-kleidiai.cpp => kleidiai/kleidiai.cpp} (92%) rename ggml/src/ggml-cpu/{ggml-kleidiai/ggml-kleidiai.h => kleidiai/kleidiai.h} (94%) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 043b12f05..858082046 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -356,10 +356,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() list(APPEND GGML_CPU_SOURCES - 
ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp - ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp - ggml-cpu/ggml-kleidiai/ggml-kleidiai.h - ggml-cpu/ggml-kleidiai/kleidiai_kernels.h + ggml-cpu/kleidiai/kleidiai.cpp + ggml-cpu/kleidiai/kernels.cpp + ggml-cpu/kleidiai/kleidiai.h + ggml-cpu/kleidiai/kernels.h ) # KleidiAI diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index b79d979db..f93f68a65 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -15,7 +15,7 @@ #endif #ifdef GGML_USE_CPU_KLEIDIAI -#include "ggml-kleidiai/ggml-kleidiai.h" +#include "kleidiai/kleidiai.h" #endif #if defined(__APPLE__) diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp similarity index 99% rename from ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp rename to ggml/src/ggml-cpu/kleidiai/kernels.cpp index 4d100c9a9..76ca62bb1 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -16,7 +16,7 @@ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" #include "kai_common.h" -#include "kleidiai_kernels.h" +#include "kernels.h" #define NELEMS(x) sizeof(x) / sizeof(*x) static ggml_kleidiai_kernels gemm_gemv_kernels[] = { diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h similarity index 98% rename from ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h rename to ggml/src/ggml-cpu/kleidiai/kernels.h index 0f97b46e9..f606eb2ef 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -2,10 +2,6 @@ // SPDX-License-Identifier: MIT // -#pragma once - -#include "ggml.h" - enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp similarity index 92% rename from 
ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp rename to ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index abd2e9ffc..c9235cd63 100644 --- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -18,18 +18,19 @@ #include #endif -#include "ggml-kleidiai.h" +#include "kleidiai.h" #include "ggml-cpu.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" -#include "kleidiai_kernels.h" +#include "kernels.h" #include "kai_common.h" -static const size_t k_q4_0_block_size = 32; +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels; @@ -78,9 +79,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t mr = kernel->get_mr(); size_t kr = kernel->get_kr(); size_t sr = kernel->get_sr(); - size_t bl = k_q4_0_block_size; - size = ctx.kernels->lhs_info.packed_size(m, k, bl, mr, kr, sr); + size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); return true; } @@ -121,7 +121,6 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t mr = kernel->get_mr(); size_t kr = kernel->get_kr(); size_t sr = kernel->get_sr(); - size_t bl = k_q4_0_block_size; // Calculate number of columns to be processed per thread const size_t num_m_per_thread = kai_roundup(m, nth) / nth; @@ -135,24 +134,24 @@ class tensor_traits : public ggml::cpu::tensor_traits { // Transform LHS const size_t src_stride = src1->nb[1]; const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, bl, mr, kr, sr); + const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - lhs_info->pack_func(m_to_process, k, bl, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr); + lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, 
src_stride, lhs_packed_ptr); } ggml_barrier(params->threadpool); // Perform the operation const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, k_q4_0_block_size, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, k_q4_0_block_size); + const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); - kernel->run_kernel(m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr, + kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); return true; } @@ -169,13 +168,13 @@ public: size_t sr = ctx.kernels->gemm.get_sr(); #ifndef NDEBUG - const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, k_q4_0_block_size); + const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); #endif struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, k_q4_0_block_size, (const uint8_t *)data, NULL, tensor->data, 0, &params); + ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params); return 0; diff --git a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h similarity index 94% rename from ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h rename to ggml/src/ggml-cpu/kleidiai/kleidiai.h index aca221e8e..6fd6b257a 100644
--- a/ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.h +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.h @@ -5,7 +5,6 @@ #pragma once #include "ggml-cpu-traits.h" -#include "ggml.h" #ifdef __cplusplus extern "C" { From 02315a8dbe74afeda8951ae727d8553e53c7eb55 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 10 Feb 2025 15:32:06 +0100 Subject: [PATCH 8/8] Move ggml-cpu-traits.h to source file --- ggml/src/ggml-cpu/kleidiai/kernels.h | 2 ++ ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 1 + ggml/src/ggml-cpu/kleidiai/kleidiai.h | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index f606eb2ef..2ffe97eb4 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -2,6 +2,8 @@ // SPDX-License-Identifier: MIT // +#pragma once + enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index c9235cd63..2c7413ff5 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -24,6 +24,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" +#include "ggml-cpu-traits.h" #include "kernels.h" diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h index 6fd6b257a..38eac58f7 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.h +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.h @@ -4,7 +4,7 @@ #pragma once -#include "ggml-cpu-traits.h" +#include "ggml-alloc.h" #ifdef __cplusplus extern "C" {