k_quants : apply review comments

This commit is contained in:
katsu560 2023-06-25 15:27:17 +09:00
parent ce95172904
commit 3d30586c31

View file

@ -1549,7 +1549,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
uint32_t aux[3]; uint32_t *aux;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
@ -1559,7 +1559,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
// Set up scales // Set up scales
memcpy(aux, x[i].scales, 12); aux = (uint32_t *)x[i].scales;
__m128i scales128 = _mm_set_epi32( __m128i scales128 = _mm_set_epi32(
((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
@ -1578,37 +1578,32 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
__m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128();
int bit = 0;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
// load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
// prepare low and high bits // prepare low and high bits
const int bit = j << 2;
const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
++bit;
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
++bit;
const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
++bit;
const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
++bit;
// load Q8 quants from block_q8_K.qs[QK_K] // load Q8 quants from block_q8_K.qs[QK_K]
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@ -1970,21 +1965,21 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const __m128i q4l_1 = _mm_and_si128(q4bits, m4); const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
__m128i q8l = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16l = _mm_maddubs_epi16(q4l_0, q8l); __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
p16l = _mm_madd_epi16(scale_l, p16l); p16l = _mm_madd_epi16(scale_l, p16l);
sumi_0 = _mm_add_epi32(sumi_0, p16l); sumi_0 = _mm_add_epi32(sumi_0, p16l);
q8l = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
p16l = _mm_maddubs_epi16(q4l_1, q8l); p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
p16l = _mm_madd_epi16(scale_l, p16l); p16l = _mm_madd_epi16(scale_l, p16l);
sumi_1 = _mm_add_epi32(sumi_1, p16l); sumi_1 = _mm_add_epi32(sumi_1, p16l);
__m128i q8h = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16h = _mm_maddubs_epi16(q4h_0, q8h); __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
p16h = _mm_madd_epi16(scale_h, p16h); p16h = _mm_madd_epi16(scale_h, p16h);
sumi_0 = _mm_add_epi32(sumi_0, p16h); sumi_0 = _mm_add_epi32(sumi_0, p16h);
q8h = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
p16h = _mm_maddubs_epi16(q4h_1, q8h); p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
p16h = _mm_madd_epi16(scale_h, p16h); p16h = _mm_madd_epi16(scale_h, p16h);
sumi_1 = _mm_add_epi32(sumi_1, p16h); sumi_1 = _mm_add_epi32(sumi_1, p16h);
@ -2253,7 +2248,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
float summs = 0.f; float summs = 0.f;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);