From 607b9c7373e25d71929e888c08b6a6bafa31b326 Mon Sep 17 00:00:00 2001 From: 3ooabkhxtn <3ooabkhxtn@local> Date: Fri, 12 May 2023 08:04:54 +0000 Subject: [PATCH] - Split multiplication and addition to make it easier for the compiler to optimise - Accumulate two acc instead of one llama_print_timings: load time = 3137.95 ms llama_print_timings: sample time = 132.54 ms / 128 runs ( 1.04 ms per token) llama_print_timings: prompt eval time = 2943.22 ms / 8 tokens ( 367.90 ms per token) llama_print_timings: eval time = 59539.50 ms / 127 runs ( 468.81 ms per token) llama_print_timings: total time = 62843.23 ms --- ggml.c | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/ggml.c b/ggml.c index a49be8858..762603c9d 100644 --- a/ggml.c +++ b/ggml.c @@ -492,6 +492,15 @@ static inline float hsum_float_4(const __m128 x) { return _mm_cvtss_f32(res); } + +// horizontally add 2x4 floats +static inline float hsum_float_2x4(const __m128 x, const __m128 y) { + __m128 res =_mm_hadd_ps(x, y); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} #endif #if __AVX__ || __AVX2__ || __AVX512F__ @@ -2141,7 +2150,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * *s = hsum_float_8(acc); #elif defined(__SSE3__) // Initialize accumulator with zeros - __m128 acc = _mm_setzero_ps(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); // Main loop for (int i = 0; i < nb; ++i) { @@ -2167,12 +2177,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * __m128 p0 = _mm_cvtepi32_ps(i32_0); __m128 p1 = _mm_cvtepi32_ps(i32_1); - // Apply the scale, and accumulate - acc = _mm_add_ps(_mm_mul_ps( d, p0 ), acc); - acc = _mm_add_ps(_mm_mul_ps( d, p1 ), acc); + // Apply the scale + __m128 p0_d = _mm_mul_ps( d, p0 ); + __m128 p1_d = _mm_mul_ps( d, p1 ); + + // Accumulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); } - *s = hsum_float_4(acc); + *s = hsum_float_2x4(acc_0, acc_1); #else // scalar float sumf = 0.0;