diff --git a/ggml-quants.c b/ggml-quants.c index 8c945697d..958b2bca9 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -8459,8 +8459,6 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const __m128i m4 = _mm_set1_epi8(0xf); const __m128i m1 = _mm_set1_epi8(1); const __m256i m511 = _mm256_set1_epi16(511); - const __m256i m127 = _mm256_set1_epi16(127); - const __m256i mxf = _mm256_set1_epi16(0xf); const __m256i mone = _mm256_set1_epi8(1); static const uint8_t k_bit_helper[32] = { @@ -8509,10 +8507,12 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; aux_gindex = _mm256_and_si256(q2_data, m511); - const __m256i partial_sign_bits = _mm256_and_si256(_mm256_srli_epi16(q2_data, 9), m127); - const __m256i odd_bits1 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(partial_sign_bits, mxf)); - const __m256i odd_bits2 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf)); - const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_xor_si256(odd_bits1, odd_bits2)); + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;