diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 0002ab399..2e6c16160 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -4213,9 +4213,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc);
 #elif defined(__AVX__)
-
-    __m256 accum1 = _mm256_setzero_ps();
-    __m256 accum2 = _mm256_setzero_ps();
+    __m256 accum = _mm256_setzero_ps();
     for (; ib + 1 < nb; ib += 2) {
         const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
         const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
@@ -4228,19 +4226,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
         const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
         const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
+
         const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
         const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
         const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
         const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
-        const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
-        const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
-        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
-                _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_bsrli_si128(p_1, 8)), _mm_cvtepi16_epi32(p_1)))), accum1);
-        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
-                _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_bsrli_si128(p_2, 8)), _mm_cvtepi16_epi32(p_2)))), accum2);
+        __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
+        p_1 = _mm_add_epi32(_mm_cvtepi16_epi32(_mm_bsrli_si128(p_1, 8)), _mm_cvtepi16_epi32(p_1));
+        __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
+        p_2 = _mm_add_epi32(_mm_cvtepi16_epi32(_mm_bsrli_si128(p_2, 8)), _mm_cvtepi16_epi32(p_2));
+
+        const __m256 deltas = _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d)),
+                                              _mm_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)));
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1))), accum);
     }
 
-    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    sumf = hsum_float_8(accum);
 #elif defined(__SSSE3__)
     // set constants
     const __m128i lowMask = _mm_set1_epi8(0xF);
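
For reference, the scalar computation this AVX path vectorizes: a Q4_0 block stores 32 weights as 4-bit values (element j in the low nibble of qs[j], element j + 16 in the high nibble, both offset by 8) plus an fp16 scale d, and a Q8_0 block stores 32 int8 values plus an fp16 scale. Below is a minimal scalar sketch assuming those standard ggml block layouts; the *_ref names and plain-float scales are illustrative simplifications (ggml stores d as ggml_half):

    #include <stddef.h>
    #include <stdint.h>

    #define QK 32                                  // QK4_0 == QK8_0 == 32

    typedef struct { float d; uint8_t qs[QK/2]; } block_q4_0_ref;
    typedef struct { float d; int8_t  qs[QK];   } block_q8_0_ref;

    // Reference dot product: sum over blocks of d_x * d_y * sum_j q4[j] * q8[j].
    static float vec_dot_q4_0_q8_0_ref(size_t nb, const block_q4_0_ref * x,
                                       const block_q8_0_ref * y) {
        float sumf = 0.0f;
        for (size_t ib = 0; ib < nb; ++ib) {
            int32_t sumi = 0;
            for (int j = 0; j < QK/2; ++j) {
                const int v0 = (x[ib].qs[j] & 0x0F) - 8; // low nibble  -> element j
                const int v1 = (x[ib].qs[j] >>   4) - 8; // high nibble -> element j + QK/2
                sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + QK/2];
            }
            sumf += x[ib].d * y[ib].d * (float) sumi;
        }
        return sumf;
    }

The mul_add_epi8_sse helper called in the hunk produces eight int16 lanes, each the sum of two adjacent signed 8-bit products. One common SSSE3 formulation, shown only as an assumption about the helper (it is defined elsewhere in ggml-quants.c, not in this diff):

    #include <immintrin.h>

    static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
        const __m128i ax = _mm_sign_epi8(x, x); // |x|; safe as unsigned since |x| <= 8
        const __m128i sy = _mm_sign_epi8(y, x); // y with x's sign transferred onto it
        return _mm_maddubs_epi16(ax, sy);       // pairwise u8*s8 products summed to i16
    }

The rewrite keeps the integer work identical but roughly halves the float work per block pair: each block's int16 partial sums are pre-reduced to four int32 lanes with _mm_cvtepi16_epi32/_mm_add_epi32, the two per-block scale products are packed into a single __m256 with _mm256_set_m128, and one FP32 multiply-add then covers both blocks, leaving a single accumulator and a single hsum_float_8 at the end.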