diff --git a/CMakeLists.txt b/CMakeLists.txt index 824d9f2cf..f7be1e54b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -421,10 +421,16 @@ if (NOT MSVC) endif() endif() -if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") +if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM64" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") message(STATUS "ARM detected") if (MSVC) # TODO: arm msvc? + # x86 add_compile_options(/arch:AVX2) + add_compile_definitions(__ARM_NEON) + add_compile_definitions(__ARM_FEATURE_FMA) + add_compile_definitions(__ARM_FEATURE_DOTPROD) + #add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + add_compile_definitions(__aarch64__) # MSVC _M_ARM64 else() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero diff --git a/ggml.c b/ggml.c index 44c43b424..057902c1f 100644 --- a/ggml.c +++ b/ggml.c @@ -272,7 +272,7 @@ typedef double ggml_float; // 16-bit float // on Arm, we use __fp16 // on x86, we use uint16_t -#ifdef __ARM_NEON +#if defined(__ARM_NEON) && !defined(_MSC_VER) // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example: // diff --git a/ggml.h b/ggml.h index 3a946dbdc..a26ae69b3 100644 --- a/ggml.h +++ b/ggml.h @@ -255,7 +255,7 @@ extern "C" { #endif -#ifdef __ARM_NEON +#if defined(__ARM_NEON) && !defined(_MSC_VER) // we use the built-in 16-bit float type typedef __fp16 ggml_fp16_t; #else diff --git a/k_quants.c b/k_quants.c index 6348fce6b..005de0c17 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2526,7 +2526,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri memcpy(utmp, x[i].scales, 12); - const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; +#ifndef _MSC_VER + uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; +#else + uint32x2_t mins8; + mins8.n64_u32[0] = utmp[1] & kmask1; + mins8.n64_u32[1] = 
((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); +#endif utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[0] &= kmask1;