diff --git a/ggml-quants.c b/ggml-quants.c index 76ea60d6f..768a27deb 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10798,6 +10798,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); const __m128i idx_mask = _mm_set1_epi32(256); typedef union { @@ -10831,17 +10833,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); idx.vec[3] = idx.vec[2]; - // AVX has no sllv so we have to do this - for (int j = 0; j < 2; ++j) { - for (int k = 0; k < 8; ++k) { - idx.index[j*8+k] <<= 8 - k; - } - } - - idx.vec[0] = _mm_and_si128(idx.vec[0], idx_mask); - idx.vec[1] = _mm_and_si128(idx.vec[1], idx_mask); - idx.vec[2] = _mm_and_si128(idx.vec[2], idx_mask); - idx.vec[3] = _mm_and_si128(idx.vec[3], idx_mask); + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));