diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 166823158..c0156e878 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -99,21 +99,36 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV}) elseif (APPLE) if (GGML_NATIVE) - set(MARCH_FLAGS "-march=armv8.2a") + set(USER_PROVIDED_MARCH FALSE) + foreach(flag_var IN ITEMS CMAKE_C_FLAGS CMAKE_CXX_FLAGS CMAKE_REQUIRED_FLAGS) + if ("${${flag_var}}" MATCHES "-march=[a-zA-Z0-9+._-]+") + set(USER_PROVIDED_MARCH TRUE) + break() + endif() + endforeach() - check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD) - if (GGML_COMPILER_SUPPORT_DOTPROD) - set(MARCH_FLAGS "${MARCH_FLAGS}+dotprod") - add_compile_definitions(__ARM_FEATURE_DOTPROD) + if (NOT USER_PROVIDED_MARCH) + set(MARCH_FLAGS "-march=armv8.2a") + + check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD) + if (GGML_COMPILER_SUPPORT_DOTPROD) + set(MARCH_FLAGS "${MARCH_FLAGS}+dotprod") + add_compile_definitions(__ARM_FEATURE_DOTPROD) + endif () + + set(TEST_I8MM_FLAGS "-march=armv8.2a+i8mm") + + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${TEST_I8MM_FLAGS}") + check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) + if (GGML_COMPILER_SUPPORT_MATMUL_INT8) + set(MARCH_FLAGS "${MARCH_FLAGS}+i8mm") + add_compile_definitions(__ARM_FEATURE_MATMUL_INT8) + endif () + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + list(APPEND ARCH_FLAGS "${MARCH_FLAGS}") endif () - - check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) - if (GGML_COMPILER_SUPPORT_MATMUL_INT8) - add_compile_definitions(__ARM_FEATURE_MATMUL_INT8) - set(MARCH_FLAGS "${MARCH_FLAGS}+i8mm") - endif () - - list(APPEND ARCH_FLAGS "${MARCH_FLAGS}") endif () else() check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)