diff --git a/ggml-quants.c b/ggml-quants.c index aa511e740..8c945697d 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -8461,16 +8461,17 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const __m256i m511 = _mm256_set1_epi16(511); const __m256i m127 = _mm256_set1_epi16(127); const __m256i mxf = _mm256_set1_epi16(0xf); - const __m256i mone16 = _mm256_set1_epi16(1); - const __m256i mone8 = _mm256_set1_epi8(1); + const __m256i mone = _mm256_set1_epi8(1); - static const uint8_t k_bit_counts[32] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4 + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, }; - static const char block_sign_shuffle_mask_bytes[64] = { + static const char block_sign_shuffle_mask_1[32] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, }; @@ -8479,13 +8480,10 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, }; - const __m256i bit_counts_table = _mm256_loadu_si256((const __m256i*)k_bit_counts); + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); - const __m256i block1_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+ 0)); - const __m256i block2_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+32)); - - //const __m256i block1_sign_shuffle = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000); - //const __m256i block2_sign_shuffle = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); uint64_t aux64; @@ -8507,17 +8505,20 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; aux_gindex = _mm256_and_si256(q2_data, m511); + const __m256i partial_sign_bits = _mm256_and_si256(_mm256_srli_epi16(q2_data, 9), m127); - const __m256i bit_counts1 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(partial_sign_bits, mxf)); - const __m256i bit_counts2 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf)); - const __m256i bit_counts = _mm256_add_epi8(bit_counts1, bit_counts2); - const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_slli_epi16(_mm256_and_si256(bit_counts, mone16), 7)); + const __m256i odd_bits1 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(partial_sign_bits, mxf)); + const __m256i odd_bits2 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf)); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_xor_si256(odd_bits1, odd_bits2)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], @@ -8532,21 +8533,22 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l); const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h); - __m256i signs = _mm256_shuffle_epi8(full_signs_1, block1_sign_shuffle); + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone8)); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - signs = _mm256_shuffle_epi8(full_signs_1, block2_sign_shuffle); + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone8)); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - signs = _mm256_shuffle_epi8(full_signs_2, block1_sign_shuffle); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone8)); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - signs = _mm256_shuffle_epi8(full_signs_2, block2_sign_shuffle); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone8)); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);