diff --git a/ggml-quants.c b/ggml-quants.c index 26f0384e7..d4524c70f 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9809,6 +9809,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void #elif defined __AVX2__ + const __m256i mask = _mm256_set1_epi16(0x7); + const __m256i mone = _mm256_set1_epi16(1); + iq1m_scale_t scale; __m256 accum1 = _mm256_setzero_ps(); @@ -9850,12 +9853,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const __m256i dot3 = mul_add_epi8(delta1, q8b_1); const __m256i dot4 = mul_add_epi8(delta2, q8b_2); - const int16_t ls1 = 2*((sc[ib/2] >> 0) & 7) + 1; - const int16_t ls2 = 2*((sc[ib/2] >> 3) & 7) + 1; - const int16_t ls3 = 2*((sc[ib/2] >> 6) & 7) + 1; - const int16_t ls4 = 2*((sc[ib/2] >> 9) & 7) + 1; - const __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1)); - const __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3)); + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); + scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p2 = _mm256_madd_epi16(dot2, scale2); const __m256i p3 = _mm256_madd_epi16(dot3, scale1);