iq2_xs
This commit is contained in:
parent
dcfee06594
commit
eccc609efa
1 changed file with 160 additions and 1 deletion
161
ggml-quants.c
161
ggml-quants.c
|
@ -8974,7 +8974,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
|||
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
|
||||
const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
|
||||
const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]]);
|
||||
const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
|
||||
const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
|
||||
const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
|
||||
const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
||||
|
@ -9362,6 +9362,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
|
||||
#elif defined(__AVX__)
|
||||
const __m128i mone = _mm_set1_epi8(1);
|
||||
static const char block_sign_shuffle_mask_1[32] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
||||
};
|
||||
static const char block_sign_shuffle_mask_2[32] = {
|
||||
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
|
||||
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
|
||||
};
|
||||
static const uint8_t bit_selector_mask_bytes[32] = {
|
||||
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
};
|
||||
|
||||
const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
|
||||
const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
|
||||
const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
|
||||
const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
|
||||
const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
|
||||
const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
|
||||
|
||||
static const uint8_t k_bit_helper[32] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
|
||||
const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
|
||||
const __m128i m511 = _mm_set1_epi16(511);
|
||||
const __m128i m4 = _mm_set1_epi8(0xf);
|
||||
const __m128i m1 = _mm_set1_epi8(1);
|
||||
|
||||
uint64_t aux64;
|
||||
|
||||
// somewhat hacky, but gives a significant boost in performance
|
||||
__m256i aux_gindex;
|
||||
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
|
||||
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint16_t * restrict q2 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
memcpy(&aux64, x[i].scales, 8);
|
||||
__m128i stmp = _mm_set1_epi64x(aux64);
|
||||
stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
|
||||
const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
|
||||
|
||||
__m128i sumi1_0 = _mm_setzero_si128();
|
||||
__m128i sumi1_1 = _mm_setzero_si128();
|
||||
__m128i sumi2_0 = _mm_setzero_si128();
|
||||
__m128i sumi2_1 = _mm_setzero_si128();
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
|
||||
|
||||
const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
|
||||
const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
|
||||
aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
|
||||
|
||||
const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
|
||||
const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
|
||||
const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
|
||||
const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
|
||||
const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
|
||||
const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
|
||||
|
||||
const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
|
||||
const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
|
||||
const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
|
||||
const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
|
||||
|
||||
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
|
||||
const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
|
||||
const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
|
||||
const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
|
||||
const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
|
||||
const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
|
||||
const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
|
||||
const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
|
||||
const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
|
||||
|
||||
// AVX2 full_signs_1 is full_sign_bits_0 here
|
||||
// AVX2 full_signs_2 is full_sign_bits_1 here
|
||||
__m128i signs_0, signs_1;
|
||||
signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
|
||||
signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
|
||||
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
|
||||
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
|
||||
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
|
||||
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
|
||||
|
||||
signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
|
||||
signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
|
||||
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
|
||||
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
|
||||
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
|
||||
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
|
||||
|
||||
signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
|
||||
signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
|
||||
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
|
||||
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
|
||||
const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
|
||||
const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
|
||||
|
||||
signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
|
||||
signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
|
||||
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
|
||||
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
|
||||
const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
|
||||
const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
|
||||
|
||||
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
|
||||
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
|
||||
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
|
||||
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
|
||||
const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
|
||||
const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
|
||||
const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
|
||||
const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
|
||||
|
||||
__m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
|
||||
const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
|
||||
const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
|
||||
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
|
||||
const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
|
||||
const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
|
||||
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
|
||||
const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
|
||||
const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
|
||||
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
|
||||
const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
|
||||
const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
|
||||
|
||||
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
|
||||
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
|
||||
sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
|
||||
sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
|
||||
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
|
||||
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
|
||||
sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
|
||||
sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
|
||||
}
|
||||
|
||||
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
|
||||
|
||||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
|
||||
#elif defined(__loongarch_asx)
|
||||
|
||||
const __m256i mone = __lasx_xvreplgr2vr_b(1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue