iq2xs: small AVX2 imrovement
This commit is contained in:
parent
76f7befaa1
commit
f9d22dab25
1 changed files with 26 additions and 24 deletions
|
@ -8461,16 +8461,17 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|||
const __m256i m511 = _mm256_set1_epi16(511);
|
||||
const __m256i m127 = _mm256_set1_epi16(127);
|
||||
const __m256i mxf = _mm256_set1_epi16(0xf);
|
||||
const __m256i mone16 = _mm256_set1_epi16(1);
|
||||
const __m256i mone8 = _mm256_set1_epi8(1);
|
||||
const __m256i mone = _mm256_set1_epi8(1);
|
||||
|
||||
static const uint8_t k_bit_counts[32] = {
|
||||
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
|
||||
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
|
||||
static const uint8_t k_bit_helper[32] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
static const char block_sign_shuffle_mask_bytes[64] = {
|
||||
static const char block_sign_shuffle_mask_1[32] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
||||
};
|
||||
static const char block_sign_shuffle_mask_2[32] = {
|
||||
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
|
||||
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
|
||||
};
|
||||
|
@ -8479,13 +8480,10 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|||
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
};
|
||||
|
||||
const __m256i bit_counts_table = _mm256_loadu_si256((const __m256i*)k_bit_counts);
|
||||
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
||||
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
||||
const __m256i block1_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+ 0));
|
||||
const __m256i block2_sign_shuffle = _mm256_loadu_si256((const __m256i*)(block_sign_shuffle_mask_bytes+32));
|
||||
|
||||
//const __m256i block1_sign_shuffle = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
|
||||
//const __m256i block2_sign_shuffle = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
|
||||
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
||||
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
||||
|
||||
uint64_t aux64;
|
||||
|
||||
|
@ -8507,17 +8505,20 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|||
__m256i sumi1 = _mm256_setzero_si256();
|
||||
__m256i sumi2 = _mm256_setzero_si256();
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
|
||||
|
||||
const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
|
||||
aux_gindex = _mm256_and_si256(q2_data, m511);
|
||||
|
||||
const __m256i partial_sign_bits = _mm256_and_si256(_mm256_srli_epi16(q2_data, 9), m127);
|
||||
const __m256i bit_counts1 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(partial_sign_bits, mxf));
|
||||
const __m256i bit_counts2 = _mm256_shuffle_epi8(bit_counts_table, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf));
|
||||
const __m256i bit_counts = _mm256_add_epi8(bit_counts1, bit_counts2);
|
||||
const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_slli_epi16(_mm256_and_si256(bit_counts, mone16), 7));
|
||||
const __m256i odd_bits1 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(partial_sign_bits, mxf));
|
||||
const __m256i odd_bits2 = _mm256_shuffle_epi8(bit_helper, _mm256_and_si256(_mm256_srli_epi16(partial_sign_bits, 4), mxf));
|
||||
const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, _mm256_xor_si256(odd_bits1, odd_bits2));
|
||||
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
|
||||
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
|
||||
iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
|
||||
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
|
||||
|
@ -8532,21 +8533,22 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|||
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
|
||||
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
|
||||
|
||||
__m256i signs = _mm256_shuffle_epi8(full_signs_1, block1_sign_shuffle);
|
||||
__m256i signs;
|
||||
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone8));
|
||||
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
|
||||
|
||||
signs = _mm256_shuffle_epi8(full_signs_1, block2_sign_shuffle);
|
||||
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone8));
|
||||
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
|
||||
|
||||
signs = _mm256_shuffle_epi8(full_signs_2, block1_sign_shuffle);
|
||||
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone8));
|
||||
const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
|
||||
|
||||
signs = _mm256_shuffle_epi8(full_signs_2, block2_sign_shuffle);
|
||||
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone8));
|
||||
const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
|
||||
|
||||
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
||||
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue