From 073b4b8ba5e46748672a6ef2ee3867baf1950657 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Mon, 5 Jun 2023 09:27:07 +0800 Subject: [PATCH] fix(avx): workaround for missing _mm256_setr_m128i in GCC < 8 --- ggml.c | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index 00bbee503..25753d5da 100644 --- a/ggml.c +++ b/ggml.c @@ -474,6 +474,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { @@ -533,7 +535,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) { static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); const __m256i lowMask = _mm256_set1_epi8( 0xF ); return _mm256_and_si256(lowMask, bytes); } @@ -606,7 +608,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) { bytesh = _mm_or_si128(bytesh, bit_mask); bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return _mm256_set_m128i(bytesh, bytesl); + return MM256_SET_M128I(bytesh, bytesl); } // Unpack 32 4-bit fields into 32 bytes @@ -619,7 +621,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) const __m128i lowMask = _mm_set1_epi8(0xF); tmpl = _mm_and_si128(lowMask, tmpl); tmph = _mm_and_si128(lowMask, tmph); - return _mm256_set_m128i(tmph, tmpl); + return MM256_SET_M128I(tmph, tmpl); } // add int16_t pairwise and return as float vector @@ -627,7 +629,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { const __m128i ones = _mm_set1_epi16(1); const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); return _mm256_cvtepi32_ps(summed_pairs); } @@ -2290,7 +2292,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const __m128i i32_1 = mul_sum_i8_pairs(bx, by); // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1)); + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); // Apply the scale, and accumulate acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); @@ -2766,7 +2768,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * __m128i bxh = _mm256_extractf128_si256(bx, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = _mm256_set_m128i(bxh, bxl); + bx = MM256_SET_M128I(bxh, bxl); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3022,7 +3024,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * __m128i bxh = _mm256_extractf128_si256(bx, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = _mm256_set_m128i(bxh, bxl); + bx = MM256_SET_M128I(bxh, bxl); const __m256 dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);