Merge 02315a8dbe into d7b31a9d84
				
					
				
			This commit is contained in:
		
						commit
						782075ae8e
					
				
					 9 changed files with 761 additions and 10 deletions
				
			
		|  | @ -102,6 +102,7 @@ endif() | |||
| 
 | ||||
| option(GGML_CPU_HBM          "ggml: use memkind for CPU HBM" OFF) | ||||
| option(GGML_CPU_AARCH64      "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) | ||||
| option(GGML_CPU_KLEIDIAI     "ggml: use KleidiAI optimized kernels if applicable" OFF) | ||||
| option(GGML_AVX              "ggml: enable AVX"              ${INS_ENB}) | ||||
| option(GGML_AVX_VNNI         "ggml: enable AVX-VNNI"         OFF) | ||||
| option(GGML_AVX2             "ggml: enable AVX2"             ${INS_ENB}) | ||||
|  |  | |||
|  | @ -95,6 +95,7 @@ extern "C" { | |||
|     GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); | ||||
|     GGML_BACKEND_API int ggml_cpu_has_sve        (void); | ||||
|     GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
 | ||||
|     GGML_BACKEND_API int ggml_cpu_has_sme        (void); | ||||
|     // other
 | ||||
|     GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void); | ||||
|     GGML_BACKEND_API int ggml_cpu_has_vsx        (void); | ||||
|  |  | |||
|  | @ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) | |||
|                 function(check_arm_feature tag code) | ||||
|                     set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) | ||||
|                     set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}") | ||||
|                     check_cxx_source_runs( | ||||
|                         "${code}" | ||||
|                         GGML_MACHINE_SUPPORTS_${tag} | ||||
|                     ) | ||||
|                     check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag}) | ||||
|                     if (GGML_MACHINE_SUPPORTS_${tag}) | ||||
|                         set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) | ||||
|                     else() | ||||
|                         set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) | ||||
|                         set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}") | ||||
|                         check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag}) | ||||
|                         if (GGML_MACHINE_SUPPORTS_no${tag}) | ||||
|                             set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) | ||||
|                         endif() | ||||
|                     endif() | ||||
|                     set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) | ||||
|                 endfunction() | ||||
|  | @ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) | |||
|                 check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") | ||||
|                 check_arm_feature(i8mm    "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") | ||||
|                 check_arm_feature(sve     "#include <arm_sve.h>\nint main()  { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") | ||||
|                 check_arm_feature(sme     "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }") | ||||
| 
 | ||||
|                 list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}") | ||||
|             else() | ||||
|  | @ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) | |||
|             if (ARM_FEATURE_RESULT) | ||||
|                 message(WARNING "Failed to get ARM features") | ||||
|             else() | ||||
|                 foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC) | ||||
|                 foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) | ||||
|                     string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) | ||||
|                     if (NOT ${feature_pos} EQUAL -1) | ||||
|                         message(STATUS "ARM feature ${feature} enabled") | ||||
|  | @ -316,6 +318,95 @@ function(ggml_add_cpu_backend_variant_impl tag_name) | |||
|         target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64) | ||||
|     endif() | ||||
| 
 | ||||
|     if (GGML_CPU_KLEIDIAI) | ||||
|         message(STATUS "Using KleidiAI optimized kernels if applicable") | ||||
| 
 | ||||
|         # Disable the KleidiAI tests | ||||
|         set(KLEIDIAI_BUILD_TESTS  OFF) | ||||
| 
 | ||||
|         # Fetch KleidiAI sources: | ||||
|         include(FetchContent) | ||||
|         set(KLEIDIAI_COMMIT_TAG "v1.3.0") | ||||
|         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") | ||||
|         set(KLEIDIAI_ARCHIVE_MD5  "060bd2dc64642b091f461cc8dd7426d9") | ||||
| 
 | ||||
|         if (POLICY CMP0135) | ||||
|             cmake_policy(SET CMP0135 NEW) | ||||
|         endif() | ||||
| 
 | ||||
|         FetchContent_Declare(KleidiAI_Download | ||||
|             URL ${KLEIDIAI_DOWNLOAD_URL} | ||||
|             DOWNLOAD_EXTRACT_TIMESTAMP NEW | ||||
|             URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) | ||||
| 
 | ||||
|         FetchContent_MakeAvailable(KleidiAI_Download) | ||||
|         FetchContent_GetProperties(KleidiAI_Download | ||||
|             SOURCE_DIR  KLEIDIAI_SRC | ||||
|             POPULATED   KLEIDIAI_POPULATED) | ||||
| 
 | ||||
|         if (NOT KLEIDIAI_POPULATED) | ||||
|             message(FATAL_ERROR "KleidiAI source downloaded failed.") | ||||
|         endif() | ||||
| 
 | ||||
|         add_compile_definitions(GGML_USE_CPU_KLEIDIAI) | ||||
| 
 | ||||
|         # Remove kleidiai target after fetching it | ||||
|         if (TARGET kleidiai) | ||||
|             set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) | ||||
|         endif() | ||||
| 
 | ||||
|         list(APPEND GGML_CPU_SOURCES | ||||
|             ggml-cpu/kleidiai/kleidiai.cpp | ||||
|             ggml-cpu/kleidiai/kernels.cpp | ||||
|             ggml-cpu/kleidiai/kleidiai.h | ||||
|             ggml-cpu/kleidiai/kernels.h | ||||
|             ) | ||||
| 
 | ||||
|         # KleidiAI | ||||
|         include_directories( | ||||
|             ${KLEIDIAI_SRC}/ | ||||
|             ${KLEIDIAI_SRC}/kai/ | ||||
|             ${KLEIDIAI_SRC}/kai/ukernels/ | ||||
|             ${KLEIDIAI_SRC}/kai/ukernels/matmul/ | ||||
|             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ | ||||
|             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) | ||||
| 
 | ||||
|         set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") | ||||
|         if (NOT ARCH_FLAGS_TEMP) | ||||
|             string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}") | ||||
|         endif() | ||||
|         string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED) | ||||
|         string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) | ||||
|         string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) | ||||
| 
 | ||||
|         set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) | ||||
| 
 | ||||
|         list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) | ||||
|         list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) | ||||
|         list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) | ||||
|         list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) | ||||
| 
 | ||||
|         if (NOT DOTPROD_ENABLED MATCHES -1) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) | ||||
|         endif() | ||||
| 
 | ||||
|         if (NOT I8MM_ENABLED MATCHES -1) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c) | ||||
|         endif() | ||||
| 
 | ||||
|         if (NOT SME_ENABLED MATCHES -1) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) | ||||
|             list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) | ||||
|             set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") | ||||
|         endif() | ||||
| 
 | ||||
|         list(APPEND GGML_CDEF_PUBLIC GGML_USE_CPU_KLEIDIAI) | ||||
|         set_source_files_properties("${GGML_KLEIDIAI_SOURCES}" PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") | ||||
|         list(APPEND GGML_CPU_SOURCES "${GGML_KLEIDIAI_SOURCES}") | ||||
|     endif() | ||||
| 
 | ||||
|     message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}") | ||||
|     target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES}) | ||||
|     target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) | ||||
|  |  | |||
|  | @ -114,7 +114,8 @@ struct ggml_arm_arch_features_type { | |||
|     int has_i8mm; | ||||
|     int has_sve; | ||||
|     int sve_cnt; | ||||
| } ggml_arm_arch_features = {-1, -1, -1, -1, 0}; | ||||
|     int has_sme; | ||||
| } ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1}; | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
|  | @ -2383,15 +2384,20 @@ bool ggml_is_numa(void) { | |||
| #define HWCAP2_I8MM (1 << 13) | ||||
| #endif | ||||
| 
 | ||||
| #if !defined(HWCAP2_SME) | ||||
| #define HWCAP2_SME (1 << 23) | ||||
| #endif | ||||
| 
 | ||||
| static void ggml_init_arm_arch_features(void) { | ||||
| #if defined(__linux__) && defined(__aarch64__) | ||||
|     uint32_t hwcap = getauxval(AT_HWCAP); | ||||
|     uint32_t hwcap2 = getauxval(AT_HWCAP2); | ||||
| 
 | ||||
|     ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); | ||||
|     ggml_arm_arch_features.has_neon    = !!(hwcap & HWCAP_ASIMD); | ||||
|     ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); | ||||
|     ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); | ||||
|     ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE); | ||||
|     ggml_arm_arch_features.has_i8mm    = !!(hwcap2 & HWCAP2_I8MM); | ||||
|     ggml_arm_arch_features.has_sve     = !!(hwcap & HWCAP_SVE); | ||||
|     ggml_arm_arch_features.has_sme     = !!(hwcap2 & HWCAP2_SME); | ||||
| 
 | ||||
| #if defined(__ARM_FEATURE_SVE) | ||||
|     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); | ||||
|  | @ -2414,6 +2420,11 @@ static void ggml_init_arm_arch_features(void) { | |||
|     } | ||||
|     ggml_arm_arch_features.has_i8mm = oldp; | ||||
| 
 | ||||
|     if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) { | ||||
|         oldp = 0; | ||||
|     } | ||||
|     ggml_arm_arch_features.has_sme = oldp; | ||||
| 
 | ||||
|     ggml_arm_arch_features.has_sve = 0; | ||||
|     ggml_arm_arch_features.sve_cnt = 0; | ||||
| #else | ||||
|  | @ -2437,6 +2448,12 @@ static void ggml_init_arm_arch_features(void) { | |||
|     ggml_arm_arch_features.has_sve = 0; | ||||
|     ggml_arm_arch_features.sve_cnt = 0; | ||||
| #endif | ||||
| 
 | ||||
| #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2) | ||||
|     ggml_arm_arch_features.has_sme = 1; | ||||
| #else | ||||
|     ggml_arm_arch_features.has_sme = 0; | ||||
| #endif | ||||
| #endif | ||||
| } | ||||
| #endif | ||||
|  | @ -14347,6 +14364,14 @@ int ggml_cpu_get_sve_cnt(void) { | |||
| #endif | ||||
| } | ||||
| 
 | ||||
// Reports whether SME is available: the has_sme value recorded in
// ggml_arm_arch_features when built for ARM with __ARM_FEATURE_SME defined,
// and 0 in every other build configuration.
int ggml_cpu_has_sme(void) {
    int sme_available = 0;
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
    sme_available = ggml_arm_arch_features.has_sme;
#endif
    return sme_available;
}
| 
 | ||||
| void ggml_cpu_init(void) { | ||||
|     // needed to initialize f16 tables
 | ||||
|     { | ||||
|  |  | |||
|  | @ -14,6 +14,10 @@ | |||
| #include "ggml-cpu-hbm.h" | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CPU_KLEIDIAI | ||||
| #include "kleidiai/kleidiai.h" | ||||
| #endif | ||||
| 
 | ||||
| #if defined(__APPLE__) | ||||
| #include <sys/types.h> | ||||
| #include <sys/sysctl.h> | ||||
|  | @ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type | |||
|         } | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CPU_KLEIDIAI | ||||
|         if (ggml_backend_cpu_kleidiai_buffer_type()) { | ||||
|             bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); | ||||
|         } | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CPU_AARCH64 | ||||
|         if (ggml_backend_cpu_aarch64_buffer_type()) { | ||||
|             bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); | ||||
|  | @ -541,6 +551,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r | |||
|             static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); | ||||
|             features.push_back({ "SVE_CNT", sve_cnt.c_str() }); | ||||
|         } | ||||
|         if (ggml_cpu_has_sme()) { | ||||
|             features.push_back({ "SME", "1" }); | ||||
|         } | ||||
|         if (ggml_cpu_has_riscv_v()) { | ||||
|             features.push_back({ "RISCV_V", "1" }); | ||||
|         } | ||||
|  | @ -562,6 +575,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r | |||
|     #ifdef GGML_USE_OPENMP | ||||
|         features.push_back({ "OPENMP", "1" }); | ||||
|     #endif | ||||
|     #ifdef GGML_USE_CPU_KLEIDIAI | ||||
|         features.push_back({ "KLEIDIAI", "1" }); | ||||
|     #endif | ||||
|     #ifdef GGML_USE_CPU_AARCH64 | ||||
|         features.push_back({ "AARCH64_REPACK", "1" }); | ||||
|     #endif | ||||
|  |  | |||
							
								
								
									
										254
									
								
								ggml/src/ggml-cpu/kleidiai/kernels.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								ggml/src/ggml-cpu/kleidiai/kernels.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,254 @@ | |||
| // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 | ||||
| // SPDX-License-Identifier: MIT
 | ||||
| //
 | ||||
| 
 | ||||
| // KleidiAI micro-kernels
 | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" | ||||
| #include "kai_lhs_quant_pack_qsi8d32p_f32.h" | ||||
| #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" | ||||
| #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" | ||||
| #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||||
| #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" | ||||
| #include "kai_common.h" | ||||
| 
 | ||||
| #include "kernels.h" | ||||
| 
 | ||||
// Number of elements in the array `x`. The expansion and its argument are fully
// parenthesized so the macro behaves as a single value inside larger expressions
// (e.g. `total / NELEMS(arr)`).
#define NELEMS(x) (sizeof(x) / sizeof(*(x)))
| static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
| #if defined(__ARM_FEATURE_SME) | ||||
|     { | ||||
|         /* SME GEMM */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|         }, | ||||
|         /* SME GEMV */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|         }, | ||||
|         /* .lhs_info = */ { | ||||
|             /* .get_offset        = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size       = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .pack_func         = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, | ||||
|             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_SME, | ||||
|     }, | ||||
| #endif | ||||
| #if defined(__APPLE__) | ||||
| #if defined(__ARM_FEATURE_DOTPROD) | ||||
|     { | ||||
|         /* DOTPROD GEMM */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|         }, | ||||
|         /* DOTPROD GEMV */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|         }, | ||||
|         /* .lhs_info = */ { | ||||
|             /* .get_offset        = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size       = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func         = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD, | ||||
|     }, | ||||
| #endif | ||||
| #if defined(__ARM_FEATURE_MATMUL_INT8) | ||||
|     { | ||||
|         /* i8mm GEMM */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|         }, | ||||
|         /* i8mm GEMV */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|         }, | ||||
|         /* .lhs_info = */ { | ||||
|             /* .get_offset        = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .packed_size       = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .pack_func         = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, | ||||
|     }, | ||||
| #endif | ||||
| #else | ||||
| #if defined(__ARM_FEATURE_MATMUL_INT8) | ||||
|     { | ||||
|         /* i8mm GEMM */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|         }, | ||||
|         /* i8mm GEMV */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|         }, | ||||
|         /* .lhs_info = */ { | ||||
|             /* .get_offset        = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .packed_size       = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|         /* .pack_func         = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, | ||||
|     }, | ||||
| #endif | ||||
| #if defined(__ARM_FEATURE_DOTPROD) | ||||
|     { | ||||
|         /* DOTPROD GEMM */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|         }, | ||||
|         /* DOTPROD GEMV */ | ||||
|         /* .kern_info = */ { | ||||
|             /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|         }, | ||||
|         /* .lhs_info = */ { | ||||
|             /* .get_offset        = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size       = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func         = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD, | ||||
|     }, | ||||
| #endif | ||||
| #endif | ||||
| }; | ||||
| 
 | ||||
| // Returns the first entry of gemm_gemv_kernels whose required_cpu bits are all
 | ||||
| // present in `features`, or nullptr when no entry qualifies.
 | ||||
| ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { | ||||
|     const size_t n_entries = NELEMS(gemm_gemv_kernels); | ||||
| 
 | ||||
|     for (size_t idx = 0; idx != n_entries; ++idx) { | ||||
|         ggml_kleidiai_kernels & candidate = gemm_gemv_kernels[idx]; | ||||
| 
 | ||||
|         // The table is scanned front to back, so the earliest matching entry wins.
 | ||||
|         const bool missing_feature = (features & candidate.required_cpu) != candidate.required_cpu; | ||||
|         if (missing_feature) { | ||||
|             continue; | ||||
|         } | ||||
|         return &candidate; | ||||
|     } | ||||
| 
 | ||||
|     return nullptr; | ||||
| } | ||||
							
								
								
									
										60
									
								
								ggml/src/ggml-cpu/kleidiai/kernels.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								ggml/src/ggml-cpu/kleidiai/kernels.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,60 @@ | |||
| // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 | ||||
| // SPDX-License-Identifier: MIT
 | ||||
| //
 | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| // Bit mask of CPU capabilities. Every non-zero enumerator occupies its own bit,
 | ||||
| // so values can be combined with | and tested with &.
 | ||||
| enum cpu_feature { | ||||
|     CPU_FEATURE_NONE    = 0, | ||||
|     CPU_FEATURE_DOTPROD = 1 << 0, | ||||
|     CPU_FEATURE_I8MM    = 1 << 1, | ||||
|     CPU_FEATURE_SVE     = 1 << 2, | ||||
|     CPU_FEATURE_SME     = 1 << 3 | ||||
| }; | ||||
| 
 | ||||
| // Bitwise OR of two masks, computed on the underlying integers.
 | ||||
| inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) { | ||||
|     const int lhs_bits = static_cast<int>(lhs); | ||||
|     const int rhs_bits = static_cast<int>(rhs); | ||||
|     return static_cast<cpu_feature>(lhs_bits | rhs_bits); | ||||
| } | ||||
| 
 | ||||
| // Compound OR: merges the bits of `rhs` into `lhs` and returns `lhs`.
 | ||||
| inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { | ||||
|     lhs = lhs | rhs; | ||||
|     return lhs; | ||||
| } | ||||
| 
 | ||||
| // Function-pointer table for one matmul micro-kernel variant.
 | ||||
| // Member order is significant: the kernel table fills this struct positionally.
 | ||||
| struct kernel_info { | ||||
|     // Parameterless size_t queries.
 | ||||
|     size_t (*get_m_step)(void); | ||||
|     size_t (*get_n_step)(void); | ||||
|     size_t (*get_mr)(void); | ||||
|     size_t (*get_nr)(void); | ||||
|     size_t (*get_kr)(void); | ||||
|     size_t (*get_sr)(void); | ||||
| 
 | ||||
|     // Offset and size helpers (get_lhs_offset is bound to the kernels' lhs_packed_offset getters).
 | ||||
|     size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); | ||||
|     size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); | ||||
|     size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); | ||||
|     size_t (*get_dst_size)(size_t m, size_t n); | ||||
| 
 | ||||
|     // Kernel entry point.
 | ||||
|     void (*run_kernel)( | ||||
|         size_t m, size_t n, size_t k, size_t bl, | ||||
|         const void* lhs_packed, const void* rhs_packed, | ||||
|         float* dst, size_t dst_stride_row, size_t dst_stride_col, | ||||
|         float scalar_min, float scalar_max); | ||||
| }; | ||||
| 
 | ||||
| // Function-pointer table for the LHS packing routine that feeds a kernel.
 | ||||
| // Member order is significant: the kernel table fills this struct positionally.
 | ||||
| struct lhs_packing_info { | ||||
|     // Offset into the unpacked source for row m_idx.
 | ||||
|     size_t (*get_offset)(size_t m_idx, size_t lhs_stride); | ||||
| 
 | ||||
|     // Offset into, and total size of, the packed buffer.
 | ||||
|     size_t (*get_packed_offset)( | ||||
|         size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); | ||||
|     size_t (*packed_size)( | ||||
|         size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); | ||||
| 
 | ||||
|     // Packs float rows from `lhs` into `lhs_packed`.
 | ||||
|     void (*pack_func)( | ||||
|         size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, | ||||
|         size_t m_idx_start, const float* lhs, size_t lhs_stride, void* lhs_packed); | ||||
| }; | ||||
| 
 | ||||
| // Function-pointer table for the RHS packing routine.
 | ||||
| // Member order is significant: the kernel table fills this struct positionally.
 | ||||
| struct rhs_packing_info { | ||||
|     // Size of the packed RHS buffer.
 | ||||
|     size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); | ||||
| 
 | ||||
|     // Packs `rhs` (and optional `bias`) into `rhs_packed`.
 | ||||
|     void (*pack_func)( | ||||
|         size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, | ||||
|         const uint8_t* rhs, const float* bias, void* rhs_packed, size_t extra_bytes, | ||||
|         const struct kai_rhs_pack_qs4cxs1s0_param* params); | ||||
| }; | ||||
| 
 | ||||
| // One selectable kernel set: a GEMM variant, a GEMV variant, the matching
 | ||||
| // LHS/RHS packing routines and the CPU features the set requires.
 | ||||
| // Member order is significant: the kernel table fills this struct positionally.
 | ||||
| struct ggml_kleidiai_kernels { | ||||
|     kernel_info      gemm; | ||||
|     kernel_info      gemv; | ||||
|     lhs_packing_info lhs_info; | ||||
|     rhs_packing_info rhs_info; | ||||
| 
 | ||||
|     // All of these bits must be present for the set to be selected.
 | ||||
|     cpu_feature required_cpu; | ||||
| }; | ||||
| 
 | ||||
| ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); | ||||
							
								
								
									
										286
									
								
								ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										286
									
								
								ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,286 @@ | |||
| // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 | ||||
| // SPDX-License-Identifier: MIT
 | ||||
| //
 | ||||
| #include <arm_neon.h> | ||||
| #include <assert.h> | ||||
| #include <cfloat> | ||||
| #include <stdint.h> | ||||
| #include <string.h> | ||||
| #if defined(__linux__) | ||||
| #include <asm/hwcap.h> | ||||
| #include <sys/auxv.h> | ||||
| #elif defined(__APPLE__) | ||||
| #include <string_view> | ||||
| #include <sys/sysctl.h> | ||||
| #include <sys/types.h> | ||||
| #elif defined(_WIN32) | ||||
| #include <windows.h> | ||||
| #include <excpt.h> | ||||
| #endif | ||||
| 
 | ||||
| #include "kleidiai.h" | ||||
| 
 | ||||
| #include "ggml-cpu.h" | ||||
| #include "ggml-impl.h" | ||||
| #include "ggml-backend-impl.h" | ||||
| #include "ggml-threading.h" | ||||
| #include "ggml-cpu-traits.h" | ||||
| 
 | ||||
| #include "kernels.h" | ||||
| 
 | ||||
| #include "kai_common.h" | ||||
| 
 | ||||
| #define GGML_COMMON_DECL_CPP | ||||
| #include "ggml-common.h" | ||||
| 
 | ||||
| // File-local KleidiAI state: the kernel set picked by init_kleidiai_context(),
 | ||||
| // or null while nothing has been selected.
 | ||||
| struct ggml_kleidiai_context { | ||||
|     ggml_kleidiai_kernels * kernels; | ||||
| }; | ||||
| 
 | ||||
| static ggml_kleidiai_context ctx = { nullptr }; | ||||
| 
 | ||||
| // One-time selection of the kernel set for the running CPU.
 | ||||
| // The body runs inside the ggml critical section; after the first call it does nothing.
 | ||||
| // SME is only considered when the GGML_KLEIDIAI_SME environment variable is set
 | ||||
| // and atoi() of its value is non-zero.
 | ||||
| static void init_kleidiai_context(void) { | ||||
|     ggml_critical_section_start(); | ||||
| 
 | ||||
|     static bool initialized = false; | ||||
|     if (!initialized) { | ||||
|         initialized = true; | ||||
| 
 | ||||
|         const char * sme_env = getenv("GGML_KLEIDIAI_SME"); | ||||
| 
 | ||||
|         cpu_feature features = CPU_FEATURE_NONE; | ||||
|         if (ggml_cpu_has_dotprod()) { | ||||
|             features |= CPU_FEATURE_DOTPROD; | ||||
|         } | ||||
|         if (ggml_cpu_has_matmul_int8()) { | ||||
|             features |= CPU_FEATURE_I8MM; | ||||
|         } | ||||
|         if (ggml_cpu_has_sve()) { | ||||
|             features |= CPU_FEATURE_SVE; | ||||
|         } | ||||
| 
 | ||||
|         // The SME probe is only made when the user opted in through the environment.
 | ||||
|         const bool sme_requested = (sme_env != nullptr) && (atoi(sme_env) != 0); | ||||
|         if (sme_requested && ggml_cpu_has_sme()) { | ||||
|             features |= CPU_FEATURE_SME; | ||||
|         } | ||||
| 
 | ||||
|         ctx.kernels = ggml_kleidiai_select_kernels(features); | ||||
|     } | ||||
| 
 | ||||
|     ggml_critical_section_end(); | ||||
| } | ||||
| 
 | ||||
| // Bounds-checked read of tensor->ne[dim]; GGML_ASSERT fires when dim is outside [0, GGML_MAX_DIMS).
 | ||||
| static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { | ||||
|     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); | ||||
| 
 | ||||
|     const int64_t extent = tensor->ne[dim]; | ||||
|     return extent; | ||||
| } | ||||
| 
 | ||||
| namespace ggml::cpu::kleidiai { | ||||
| class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|     // Reports in `size` the scratch space compute_forward() needs for `op`: the value
 | ||||
|     // returned by lhs_info.packed_size for the rows of src[1]. Always returns true.
 | ||||
|     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { | ||||
|         GGML_ASSERT(ctx.kernels); | ||||
| 
 | ||||
|         const struct ggml_tensor * rhs_tensor = op->src[0]; | ||||
|         const struct ggml_tensor * lhs_tensor = op->src[1]; | ||||
| 
 | ||||
|         // A single src[1] row selects the GEMV variant, anything else the GEMM variant.
 | ||||
|         const bool single_row = lhs_tensor->ne[1] == 1; | ||||
|         kernel_info & variant = single_row ? ctx.kernels->gemv : ctx.kernels->gemm; | ||||
| 
 | ||||
|         const size_t depth = rhs_tensor->ne[0]; | ||||
|         const size_t rows  = lhs_tensor->ne[1]; | ||||
| 
 | ||||
|         const size_t mr = variant.get_mr(); | ||||
|         const size_t kr = variant.get_kr(); | ||||
|         const size_t sr = variant.get_sr(); | ||||
| 
 | ||||
|         size = ctx.kernels->lhs_info.packed_size(rows, depth, QK4_0, mr, kr, sr); | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
|     // Computes dst = MUL_MAT(src0, src1) with the selected kernel set; returns false for any other op.
 | ||||
|     // Each thread (params->ith of params->nth):
 | ||||
|     //   1. packs its share of src1 rows into params->wdata through lhs_info->pack_func,
 | ||||
|     //   2. waits on the threadpool barrier so the whole packed LHS is available,
 | ||||
|     //   3. runs the kernel over its own range of src0 rows.
 | ||||
|     // Threads whose range starts past the end do no work but still reach the barrier.
 | ||||
|     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { | ||||
|         if (dst->op == GGML_OP_MUL_MAT) { | ||||
|             const ggml_tensor * src0 = dst->src[0]; | ||||
|             const ggml_tensor * src1 = dst->src[1]; | ||||
| 
 | ||||
|             GGML_TENSOR_BINARY_OP_LOCALS | ||||
| 
 | ||||
|             GGML_ASSERT(ctx.kernels); | ||||
|             kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; | ||||
|             lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; | ||||
| 
 | ||||
|             GGML_ASSERT(kernel); | ||||
| 
 | ||||
|             const int ith = params->ith; | ||||
|             const int nth = params->nth; | ||||
| 
 | ||||
|             const size_t k = ne00; | ||||
|             const size_t m = ne11; | ||||
|             const size_t n = ne01; | ||||
| 
 | ||||
|             const size_t n_step = kernel->get_n_step(); | ||||
|             const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); | ||||
|             const size_t n_start = ith * num_n_per_thread; | ||||
| 
 | ||||
|             // Rounding the per-thread share up to n_step can push n_start past n for the
 | ||||
|             // last threads; those threads get an empty range instead of a wrapped-around
 | ||||
|             // size_t difference.
 | ||||
|             size_t n_to_process = 0; | ||||
|             if (n_start < n) { | ||||
|                 n_to_process = num_n_per_thread; | ||||
|                 if ((n_start + n_to_process) > n) { | ||||
|                     n_to_process = n - n_start; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data); | ||||
|             uint8_t * lhs_packed       = (uint8_t*)params->wdata; | ||||
|             const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data); | ||||
| 
 | ||||
|             size_t mr = kernel->get_mr(); | ||||
|             size_t kr = kernel->get_kr(); | ||||
|             size_t sr = kernel->get_sr(); | ||||
| 
 | ||||
|             // Split the src1 rows to be packed across the threads
 | ||||
|             const size_t num_m_per_thread = kai_roundup(m, nth) / nth; | ||||
|             const size_t m_start = ith * num_m_per_thread; | ||||
| 
 | ||||
|             if (m_start < m) { | ||||
|                 // The row count is only computed once m_start is known to be in range.
 | ||||
|                 size_t m_to_process = num_m_per_thread; | ||||
|                 if ((m_start + m_to_process) > m) { | ||||
|                     m_to_process = m - m_start; | ||||
|                 } | ||||
| 
 | ||||
|                 // Transform LHS
 | ||||
|                 const size_t src_stride        = src1->nb[1]; | ||||
|                 const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); | ||||
|                 const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); | ||||
|                 void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset); | ||||
| 
 | ||||
|                 lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr); | ||||
|             } | ||||
| 
 | ||||
|             // Every thread must reach the barrier, including those without work.
 | ||||
|             ggml_barrier(params->threadpool); | ||||
| 
 | ||||
|             if (n_to_process > 0) { | ||||
|                 // Perform the operation
 | ||||
|                 const size_t dst_stride        = dst->nb[1]; | ||||
|                 const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); | ||||
|                 const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); | ||||
|                 const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride); | ||||
|                 const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset); | ||||
|                 const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset); | ||||
|                 float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); | ||||
| 
 | ||||
|                 // -FLT_MAX / FLT_MAX are passed as scalar_min / scalar_max.
 | ||||
|                 kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, | ||||
|                                    dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); | ||||
|             } | ||||
|             return true; | ||||
|         } | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
| public: | ||||
|     // Repacks `data` into tensor->data through rhs_info.pack_func, using the nr/kr/sr
 | ||||
|     // of the GEMM variant and a block length of QK4_0. Always returns 0.
 | ||||
|     // In builds without NDEBUG it first asserts that the packed size fits in data_size.
 | ||||
|     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { | ||||
|         GGML_ASSERT(ctx.kernels); | ||||
|         const size_t n = tensor->ne[1]; | ||||
|         const size_t k = tensor->ne[0]; | ||||
|         size_t nr      = ctx.kernels->gemm.get_nr(); | ||||
|         size_t kr      = ctx.kernels->gemm.get_kr(); | ||||
|         size_t sr      = ctx.kernels->gemm.get_sr(); | ||||
| 
 | ||||
| #ifndef NDEBUG | ||||
|         const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); | ||||
|         GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); | ||||
| #endif | ||||
|         struct kai_rhs_pack_qs4cxs1s0_param params; | ||||
|         params.lhs_zero_point = 1; | ||||
|         params.rhs_zero_point = 8; | ||||
|         // The last argument is the address of `params`.
 | ||||
|         ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params); | ||||
| 
 | ||||
|         return 0; | ||||
| 
 | ||||
|         GGML_UNUSED(data_size); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| // Returns the single function-local tensor_traits instance; both arguments are ignored.
 | ||||
| static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { | ||||
|     static tensor_traits shared_traits; | ||||
| 
 | ||||
|     ggml::cpu::tensor_traits * as_base = &shared_traits; | ||||
|     return as_base; | ||||
| } | ||||
| }  // namespace ggml::cpu::kleidiai
 | ||||
| 
 | ||||
| // Buffer init_tensor hook: stores the KleidiAI tensor traits in tensor->extra.
 | ||||
| static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { | ||||
|     ggml::cpu::tensor_traits * traits = ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); | ||||
|     tensor->extra = static_cast<void *>(traits); | ||||
| 
 | ||||
|     GGML_UNUSED(buffer); | ||||
| } | ||||
| 
 | ||||
| // Buffer set_tensor hook: only whole-tensor uploads are accepted (offset 0, size equal to
 | ||||
| // ggml_nbytes(tensor)); the data is repacked into the tensor through tensor_traits::repack.
 | ||||
| static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, | ||||
|                                                        const void * data, size_t offset, size_t size) { | ||||
|     GGML_UNUSED(buffer); | ||||
| 
 | ||||
|     GGML_ASSERT(offset == 0); | ||||
|     GGML_ASSERT(size == ggml_nbytes(tensor)); | ||||
| 
 | ||||
|     // tensor->extra holds the traits pointer stored by the init_tensor hook.
 | ||||
|     auto * traits = static_cast<ggml::cpu::kleidiai::tensor_traits *>(tensor->extra); | ||||
|     const int OK  = traits->repack(tensor, data, size); | ||||
| 
 | ||||
|     GGML_ASSERT(OK == 0); | ||||
| } | ||||
| 
 | ||||
| // Buffer-type get_name hook: returns the fixed name of this buffer type.
 | ||||
| static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { | ||||
|     GGML_UNUSED(buft); | ||||
| 
 | ||||
|     return "CPU_KLEIDIAI"; | ||||
| } | ||||
| 
 | ||||
| // Buffer-type alloc_buffer hook: allocates through the regular CPU buffer type, then
 | ||||
| // re-tags the buffer with `buft` and installs the KleidiAI tensor hooks.
 | ||||
| // Returns nullptr when the underlying allocation fails.
 | ||||
| static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { | ||||
|     ggml_backend_buffer_t host_buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); | ||||
|     if (!host_buffer) { | ||||
|         return nullptr; | ||||
|     } | ||||
| 
 | ||||
|     host_buffer->buft = buft; | ||||
| 
 | ||||
|     auto & hooks = host_buffer->iface; | ||||
|     hooks.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor; | ||||
|     hooks.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor; | ||||
|     // get_tensor and cpy_tensor are cleared rather than inherited from the CPU buffer.
 | ||||
|     hooks.get_tensor  = nullptr; | ||||
|     hooks.cpy_tensor  = nullptr; | ||||
| 
 | ||||
|     return host_buffer; | ||||
| } | ||||
| 
 | ||||
| // Buffer-type get_alignment hook: returns TENSOR_ALIGNMENT regardless of `buft`.
 | ||||
| static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { | ||||
|     GGML_UNUSED(buft); | ||||
| 
 | ||||
|     const size_t alignment = TENSOR_ALIGNMENT; | ||||
|     return alignment; | ||||
| } | ||||
| 
 | ||||
| namespace ggml::cpu::kleidiai { | ||||
| // Extra CPU buffer type hooks for KleidiAI: decides which ops the kernels handle and
 | ||||
| // returns the tensor traits stored on tensors that live in the KleidiAI buffer type.
 | ||||
| class extra_buffer_type : ggml::cpu::extra_buffer_type { | ||||
|     // True only for MUL_MAT where src[0] is a 2-D Q4_0 tensor in the KleidiAI buffer type,
 | ||||
|     // a kernel set has been selected, src[1] is F32 with ne[2] == ne[3] == 1 and src[1]
 | ||||
|     // is not in a non-host buffer.
 | ||||
|     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { | ||||
|         if (op->op != GGML_OP_MUL_MAT) { | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         const struct ggml_tensor * rhs = op->src[0]; | ||||
|         if (rhs->type != GGML_TYPE_Q4_0 || !rhs->buffer || ggml_n_dims(rhs) != 2) { | ||||
|             return false; | ||||
|         } | ||||
|         // The buffer-type comparison comes first: fetching the buffer type also runs the
 | ||||
|         // one-time kernel selection that fills ctx.kernels.
 | ||||
|         if (rhs->buffer->buft != ggml_backend_cpu_kleidiai_buffer_type() || !ctx.kernels) { | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         const struct ggml_tensor * lhs = op->src[1]; | ||||
|         if (lhs->buffer && !ggml_backend_buft_is_host(lhs->buffer->buft)) { | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         return lhs->type == GGML_TYPE_F32 && ggml_ne(lhs, 2) == 1 && ggml_ne(lhs, 3) == 1; | ||||
|     } | ||||
| 
 | ||||
|     // Returns src[0]->extra for MUL_MAT ops whose src[0] lives in the KleidiAI buffer
 | ||||
|     // type, nullptr otherwise.
 | ||||
|     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { | ||||
|         if (op->op != GGML_OP_MUL_MAT) { | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         const struct ggml_tensor * rhs = op->src[0]; | ||||
|         const bool in_kleidiai_buffer = rhs->buffer && rhs->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type(); | ||||
|         if (!in_kleidiai_buffer) { | ||||
|             return nullptr; | ||||
|         } | ||||
| 
 | ||||
|         return static_cast<ggml::cpu::tensor_traits *>(rhs->extra); | ||||
|     } | ||||
| }; | ||||
| }  // namespace ggml::cpu::kleidiai
 | ||||
| 
 | ||||
| // Returns the KleidiAI buffer type (a function-local static). Every call also invokes
 | ||||
| // init_kleidiai_context(), which performs the kernel selection on its first run.
 | ||||
| ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { | ||||
|     // Named differently from the file-level `ctx` so the two are not confused.
 | ||||
|     static ggml::cpu::kleidiai::extra_buffer_type extra_ctx; | ||||
|     static struct ggml_backend_buffer_type kleidiai_buft = { | ||||
|         /* .iface   = */ { | ||||
|             /* .get_name       = */ ggml_backend_cpu_kleidiai_buffer_type_get_name, | ||||
|             /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, | ||||
|             /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, | ||||
|             /* .get_max_size   = */ nullptr,  // defaults to SIZE_MAX
 | ||||
|             /* .get_alloc_size = */ nullptr,  // defaults to ggml_nbytes
 | ||||
|             /* .is_host        = */ nullptr, | ||||
|         }, | ||||
|         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), | ||||
|         /* .context = */ &extra_ctx, | ||||
|     }; | ||||
| 
 | ||||
|     init_kleidiai_context(); | ||||
| 
 | ||||
|     return &kleidiai_buft; | ||||
| } | ||||
							
								
								
									
										17
									
								
								ggml/src/ggml-cpu/kleidiai/kleidiai.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								ggml/src/ggml-cpu/kleidiai/kleidiai.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | |||
| // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 | ||||
| // SPDX-License-Identifier: MIT
 | ||||
| //
 | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include "ggml-alloc.h" | ||||
| 
 | ||||
| #ifdef  __cplusplus | ||||
| extern "C" { | ||||
| #endif | ||||
| 
 | ||||
| ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void); | ||||
| 
 | ||||
| #ifdef  __cplusplus | ||||
| } | ||||
| #endif | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue