diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index eb71aa9aa..9f6d91ed5 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -5963,32 +5963,17 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
             qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
 
-            // 0, 1, 2 => -1, 0, 1
-            qx0 = _mm256_sub_epi8(qx0, _mm256_set1_epi8(1));
-            qx1 = _mm256_sub_epi8(qx1, _mm256_set1_epi8(1));
-            qx2 = _mm256_sub_epi8(qx2, _mm256_set1_epi8(1));
-            qx3 = _mm256_sub_epi8(qx3, _mm256_set1_epi8(1));
-            qx4 = _mm256_sub_epi8(qx4, _mm256_set1_epi8(1));
-
             const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
             const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
             const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
             const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
             const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
 
-            // dot
-            qx0 = _mm256_sign_epi8(qy0, qx0);
-            qx1 = _mm256_sign_epi8(qy1, qx1);
-            qx2 = _mm256_sign_epi8(qy2, qx2);
-            qx3 = _mm256_sign_epi8(qy3, qx3);
-            qx4 = _mm256_sign_epi8(qy4, qx4);
-
-            // widening addition
-            qx0 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx0);
-            qx1 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx1);
-            qx2 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx2);
-            qx3 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx3);
-            qx4 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx4);
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+            qx4 = _mm256_maddubs_epi16(qx4, qy4);
 
             sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
             sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
@@ -6025,32 +6010,23 @@
             qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
             qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
 
-            // 0, 1, 2 => -1, 0, 1
-            qx01 = _mm256_sub_epi8(qx01, _mm256_set1_epi8(1));
-            qx23 = _mm256_sub_epi8(qx23, _mm256_set1_epi8(1));
-            qx45 = _mm256_sub_epi8(qx45, _mm256_set1_epi8(1));
-
             const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
             const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
             const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
 
-            // dot
-            qx01 = _mm256_sign_epi8(qy01, qx01);
-            qx23 = _mm256_sign_epi8(qy23, qx23);
-            qx45 = _mm256_sign_epi8(qy45, qx45);
-
-            // widening addition
-            qx01 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx01);
-            qx23 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx23);
-            qx45 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx45);
+            qx01 = _mm256_maddubs_epi16(qx01, qy01);
+            qx23 = _mm256_maddubs_epi16(qx23, qy23);
+            qx45 = _mm256_maddubs_epi16(qx45, qy45);
 
             sumi0 = _mm256_add_epi16(sumi0, qx01);
             sumi1 = _mm256_add_epi16(sumi1, qx23);
             sumi2 = _mm256_add_epi16(sumi2, qx45);
         }
 
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
         const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
 
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
         sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
         sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
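
The change leans on the identity sum((x_i - 1) * y_i) = sum(x_i * y_i) - sum(y_i): the decoded ternary digits {0, 1, 2} can go straight into `_mm256_maddubs_epi16` as the unsigned operand (pairwise 16-bit sums cannot saturate since each product is at most 2 * 127), and the "- 1" shift is applied once at the end by subtracting the per-block activation sums `y[i].bsums` already stored in the q8_K blocks. A minimal scalar sketch of that identity (standalone C, not part of the patch; values are illustrative):

```c
#include <stdio.h>
#include <stdint.h>

// Scalar illustration of the AVX2 rewrite above:
//   sum((x[i] - 1) * y[i]) == sum(x[i] * y[i]) - sum(y[i])
// so the ternary values stored as {0, 1, 2} can be multiplied as unsigned
// and the block sum of y (bsums in q8_K) subtracted once at the end.
int main(void) {
    const uint8_t x[8] = {0, 1, 2, 2, 0, 1, 2, 1};      // ternary weights, stored as 0..2
    const int8_t  y[8] = {-7, 3, 12, -5, 9, 0, -1, 4};  // q8 activations

    int old_way = 0, new_way = 0, ysum = 0;
    for (int i = 0; i < 8; ++i) {
        old_way += (x[i] - 1) * y[i]; // original path: shift to {-1, 0, 1} first
        new_way += x[i] * y[i];       // new path: multiply the raw {0, 1, 2} values
        ysum    += y[i];              // corresponds to y[i].bsums
    }
    new_way -= ysum;                  // single correction, like _mm256_sub_epi16(sumi0, ysum)

    printf("old=%d new=%d\n", old_way, new_way); // both print the same value
    return 0;
}
```

This removes the per-element `_mm256_sub_epi8` and `_mm256_sign_epi8` steps and lets `_mm256_maddubs_epi16` do the multiply and widening add in one instruction.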