From dca0deb3d866197d4013fb57cff5ba1302635bc6 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:21:21 -0400 Subject: [PATCH] avx bf16 vec dot --- ggml/src/ggml.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0d99b0791..f7fa6cf73 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2245,8 +2245,12 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t sumf += (ggml_float)_mm512_reduce_add_ps(c2); #undef LOAD -#elif defined(__AVX2__) +#elif defined(__AVX2__) || defined(__AVX__) +#if defined(__AVX2__) #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) +#else +#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1)) +#endif __m256 c1 = _mm256_setzero_ps(); __m256 c2 = _mm256_setzero_ps(); __m256 c3 = _mm256_setzero_ps();