iq2xs: small AVX2 imrovement

This commit is contained in:
Iwan Kawrakow 2024-01-29 08:52:08 +02:00
parent 76f7befaa1
commit f9d22dab25

View file

@ -8461,16 +8461,17 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
const __m256i m511 = _mm256_set1_epi16(511); const __m256i m511 = _mm256_set1_epi16(511);
const __m256i m127 = _mm256_set1_epi16(127); const __m256i m127 = _mm256_set1_epi16(127);
const __m256i mxf = _mm256_set1_epi16(0xf); const __m256i mxf = _mm256_set1_epi16(0xf);
const __m256i mone16 = _mm256_set1_epi16(1); const __m256i mone = _mm256_set1_epi8(1);
const __m256i mone8 = _mm256_set1_epi8(1);
static const uint8_t k_bit_counts[32] = { static const uint8_t k_bit_helper[32] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
}; };
static const char block_sign_shuffle_mask_bytes[64] = { static const char block_sign_shuffle_mask_1[32] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
};
static const char block_sign_shuffle_mask_2[32] = {
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
}; };
@ -8479,13 +8480,10 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
}; };
const __m256i bit_counts_table = _mm256_loadu_si256((const __m256i*)k_bit_counts); const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
const __m256i block1_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+ 0)); const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
const __m256i block2_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+32)); const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
//const __m256i block1_sign_shuffle = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
//const __m256i block2_sign_shuffle = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
uint64_t aux64; uint64_t aux64;
@ -8507,17 +8505,20 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
__m256i sumi1 = _mm256_setzero_si256(); __m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
aux_gindex = _mm256_and_si256(q2_data, m511); aux_gindex = _mm256_and_si256(q2_data, m511);
const __m256i partial_sign_bits = _mm256_and_si256(_mm256_srli_epi16(q2_data, 9), m127); const __m256i partial_sign_bits = _mm256_and_si256(_mm256_srli_epi16(q2_data, 9), m127);
const __m256i bit_counts1 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(partial_sign_bits, mxf)); const __m256i odd_bits1 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(partial_sign_bits, mxf));
const __m256i bit_counts2 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf)); const __m256i odd_bits2 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf));
const __m256i bit_counts = _mm256_add_epi8(bit_counts1, bit_counts2); const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_xor_si256(odd_bits1, odd_bits2));
const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_slli_epi16(_mm256_and_si256(bit_counts, mone16), 7));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
@ -8532,21 +8533,22 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l); const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h); const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
__m256i signs = _mm256_shuffle_epi8(full_signs_1, block1_sign_shuffle); __m256i signs;
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone8)); const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
signs = _mm256_shuffle_epi8(full_signs_1, block2_sign_shuffle); signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone8)); const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
signs = _mm256_shuffle_epi8(full_signs_2, block1_sign_shuffle); signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone8)); const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
signs = _mm256_shuffle_epi8(full_signs_2, block2_sign_shuffle); signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone8)); const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);