diff --git a/ggml.c b/ggml.c index a49be8858..762603c9d 100644 --- a/ggml.c +++ b/ggml.c @@ -492,6 +492,15 @@ static inline float hsum_float_4(const __m128 x) { return _mm_cvtss_f32(res); } + +// horizontally add 2x4 floats: returns the sum of all eight lanes of x and y +static inline float hsum_float_2x4(const __m128 x, const __m128 y) { + __m128 res =_mm_hadd_ps(x, y); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} #endif #if __AVX__ || __AVX2__ || __AVX512F__ @@ -2141,7 +2150,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * *s = hsum_float_8(acc); #elif defined(__SSE3__) // Initialize accumulator with zeros - __m128 acc = _mm_setzero_ps(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); // Main loop for (int i = 0; i < nb; ++i) { @@ -2167,12 +2177,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * __m128 p0 = _mm_cvtepi32_ps(i32_0); __m128 p1 = _mm_cvtepi32_ps(i32_1); - // Apply the scale, and accumulate - acc = _mm_add_ps(_mm_mul_ps( d, p0 ), acc); - acc = _mm_add_ps(_mm_mul_ps( d, p1 ), acc); + // Apply the scale d to p0 and p1 + __m128 p0_d = _mm_mul_ps( d, p0 ); + __m128 p1_d = _mm_mul_ps( d, p1 ); + + // Accumulate into separate accumulators; acc_0 and acc_1 are combined by hsum_float_2x4 after the loop + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); } - *s = hsum_float_4(acc); + *s = hsum_float_2x4(acc_0, acc_1); #else // scalar float sumf = 0.0;