From 708540712d09d36ee06761dbc471821b2b20d266 Mon Sep 17 00:00:00 2001 From: Stephan Walter Date: Sat, 22 Apr 2023 10:07:17 +0200 Subject: [PATCH] Q8_0: unbreak AVX --- ggml.c | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index 72b392fdb..a65750f70 100644 --- a/ggml.c +++ b/ggml.c @@ -459,6 +459,14 @@ static inline float hsum_float_8(const __m256 x) { return _mm_cvtss_f32(res); } +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + // horizontally add 8 int32_t static inline int hsum_i32_8(const __m256i a) { const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); @@ -1381,7 +1389,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].s1 = d * sum1; } #elif defined(__AVX2__) || defined(__AVX__) - // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! for (int i = 0; i < nb; i++) { // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); @@ -1428,7 +1435,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1)); y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3)); @@ -1459,8 +1465,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // Compute the sum of the quants and set y[i].s const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + y[i].s0 = d * hsum_i32_4(s0); const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0)); + y[i].s1 = d * hsum_i32_4(s1); // Convert int32 to int16 ni0 = _mm_packs_epi32( ni0, ni1 );