ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights.
This makes TQ1_0 almost as fast as Q8_0 on AVX2.
Francis Couture-Harpin 2024-07-31 00:06:21 -04:00
parent 560873f337
commit e9719576c4
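
(Note: the scalar sketch below is not part of the commit; it is a simplified model of the trick the commit message describes. With ternary weights stored as unsigned bytes q in {0, 1, 2} standing for {-1, 0, 1}, the identity dot(q - 1, y) = dot(q, y) - sum(y) lets the kernel multiply unsigned bytes directly and apply one correction per block instead of offsetting every weight; in the kernel the correction term is available precomputed as y[i].bsums. Function names here are illustrative.)

#include <stdint.h>
#include <stddef.h>

// Old approach, modeled in scalar code: offset each weight from {0,1,2} to {-1,0,1}.
static int32_t dot_offset_weights(const uint8_t * q, const int8_t * y, size_t n) {
    int32_t sum = 0;
    for (size_t j = 0; j < n; ++j) {
        sum += (q[j] - 1) * y[j];  // signed weights need extra work per element in SIMD
    }
    return sum;
}

// New approach, modeled in scalar code: keep the weights unsigned, offset the sum once.
static int32_t dot_offset_sum(const uint8_t * q, const int8_t * y, size_t n) {
    int32_t sum  = 0;
    int32_t ysum = 0;              // the kernel gets this precomputed as y[i].bsums
    for (size_t j = 0; j < n; ++j) {
        sum  += q[j] * y[j];       // unsigned x signed, the form _mm256_maddubs_epi16 handles
        ysum += y[j];
    }
    return sum - ysum;             // dot(q - 1, y) == dot(q, y) - sum(y)
}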


@@ -5963,32 +5963,17 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
     qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
     qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
-    // 0, 1, 2 => -1, 0, 1
-    qx0 = _mm256_sub_epi8(qx0, _mm256_set1_epi8(1));
-    qx1 = _mm256_sub_epi8(qx1, _mm256_set1_epi8(1));
-    qx2 = _mm256_sub_epi8(qx2, _mm256_set1_epi8(1));
-    qx3 = _mm256_sub_epi8(qx3, _mm256_set1_epi8(1));
-    qx4 = _mm256_sub_epi8(qx4, _mm256_set1_epi8(1));
     const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
     const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
     const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
     const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
     const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
-    // dot
-    qx0 = _mm256_sign_epi8(qy0, qx0);
-    qx1 = _mm256_sign_epi8(qy1, qx1);
-    qx2 = _mm256_sign_epi8(qy2, qx2);
-    qx3 = _mm256_sign_epi8(qy3, qx3);
-    qx4 = _mm256_sign_epi8(qy4, qx4);
-    // widening addition
-    qx0 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx0);
-    qx1 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx1);
-    qx2 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx2);
-    qx3 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx3);
-    qx4 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx4);
+    qx0 = _mm256_maddubs_epi16(qx0, qy0);
+    qx1 = _mm256_maddubs_epi16(qx1, qy1);
+    qx2 = _mm256_maddubs_epi16(qx2, qy2);
+    qx3 = _mm256_maddubs_epi16(qx3, qy3);
+    qx4 = _mm256_maddubs_epi16(qx4, qy4);
     sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
     sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
@@ -6025,32 +6010,23 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
     qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
     qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
-    // 0, 1, 2 => -1, 0, 1
-    qx01 = _mm256_sub_epi8(qx01, _mm256_set1_epi8(1));
-    qx23 = _mm256_sub_epi8(qx23, _mm256_set1_epi8(1));
-    qx45 = _mm256_sub_epi8(qx45, _mm256_set1_epi8(1));
     const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
     const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
     const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
-    // dot
-    qx01 = _mm256_sign_epi8(qy01, qx01);
-    qx23 = _mm256_sign_epi8(qy23, qx23);
-    qx45 = _mm256_sign_epi8(qy45, qx45);
-    // widening addition
-    qx01 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx01);
-    qx23 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx23);
-    qx45 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx45);
+    qx01 = _mm256_maddubs_epi16(qx01, qy01);
+    qx23 = _mm256_maddubs_epi16(qx23, qy23);
+    qx45 = _mm256_maddubs_epi16(qx45, qy45);
     sumi0 = _mm256_add_epi16(sumi0, qx01);
     sumi1 = _mm256_add_epi16(sumi1, qx23);
     sumi2 = _mm256_add_epi16(sumi2, qx45);
 }
+const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
 const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+sumi0 = _mm256_sub_epi16(sumi0, ysum);
 sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
 sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
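
(A reading of why this wins, inferred from the intrinsics rather than stated in the commit: _mm256_maddubs_epi16 multiplies unsigned bytes from its first operand by signed bytes from its second and pairwise-adds the products into 16-bit lanes. The old path had already mapped qx to signed {-1, 0, 1}, so it needed _mm256_sign_epi8 plus a maddubs against a vector of ones just to multiply and widen; keeping qx unsigned in {0, 1, 2} lets a single maddubs do both, and subtracting y[i].bsums (ysum) from the accumulator restores the -1 offset before the final reduction. A minimal sketch of the two per-chunk sequences, assuming qx holds unsigned bytes in {0, 1, 2} and qy holds the int8 activations:)

#include <immintrin.h>

// Old per-chunk sequence: three instructions to get widened products of {-1,0,1} weights.
static inline __m256i chunk_old(__m256i qx, __m256i qy) {
    qx = _mm256_sub_epi8(qx, _mm256_set1_epi8(1));         // {0,1,2} -> {-1,0,1}
    qx = _mm256_sign_epi8(qy, qx);                         // qy * qx, since qx is in {-1,0,1}
    return _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx);  // widening pairwise add only
}

// New per-chunk sequence: one maddubs multiplies (unsigned x signed) and widens;
// the missing "-1" is recovered once per block via the bsums subtraction above.
static inline __m256i chunk_new(__m256i qx, __m256i qy) {
    return _mm256_maddubs_epi16(qx, qy);
}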