From 2a4e792ed10fe71502a659beb061e48159e58c44 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Wed, 18 Dec 2024 23:19:36 -0800 Subject: [PATCH] Fixes for clang AVX VNNI --- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 5 ++++- ggml/src/ggml-cpu/ggml-cpu-quants.c | 6 +++++- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index a51d1a6c5..23831e560 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { } static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); return _mm256_dpbusd_epi32(zero, ax, sy); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_avx_epi32(zero, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 634c5fa11..8479cf925 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index f80a72781..a50432319 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -993,8 +993,10 @@ class tinyBLAS_Q0_AVX { inline __m256 updot(__m256i u, __m256i s) { __m256i res; -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#elif defined(__AVXVNNI__) + res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif