From 894210a3519eb39148189ea7a4094aa076bee2d7 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 2 Jun 2023 17:28:38 +0300
Subject: [PATCH] A slightly faster Q4_K AVX2 dot product

For perplexity, where we are less memory bound, time per pass drops
by ~5%. Barely measurable difference for single token prediction.
---
 k_quants.c | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/k_quants.c b/k_quants.c
index 79b021cf0..fa46b704c 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -1499,6 +1499,16 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
 }
 
+#ifdef __AVX__
+static inline int32_t hsum_int_4(__m128i v) {
+    __m128i hi64 = _mm_unpackhi_epi64(v, v);
+    __m128i sum64 = _mm_add_epi32(hi64, v);
+    __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2));
+    __m128i sum32 = _mm_add_epi32(sum64, hi32);
+    return _mm_cvtsi128_si32(sum32);
+}
+#endif
+
 void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
@@ -1596,11 +1606,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m128i mzero = _mm_setzero_si128();
 
     __m256 acc = _mm256_setzero_ps();
-
-    float summs = 0.f;
+    __m128 acc_m = _mm_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
 
@@ -1622,8 +1630,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
         const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
         const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
-        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
-        summs += dmin * _mm_extract_epi32(hsum, 0);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
         const __m256i scales = _mm256_set_m128i(sc128, sc128);
@@ -1638,16 +1645,16 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
             const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
             const __m256i q4l = _mm256_and_si256(q4bits, m4);
             const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
             const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-
             __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
-            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
-
             p16l = _mm256_madd_epi16(scale_l, p16l);
-            p16h = _mm256_madd_epi16(scale_h, p16h);
+            sumi = _mm256_add_epi32(sumi, p16l);
 
-            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16l, p16h));
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            sumi = _mm256_add_epi32(sumi, p16h);
 
         }
 
@@ -1656,7 +1663,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
 
 #else
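
---

Two notes on what changed, with small standalone sketches. These are
not part of the patch itself; test values and file names are made up
for illustration.

1) The new hsum_int_4() helper reduces the four 32-bit lanes of a
__m128i to a single integer with cheap shuffles, instead of the
_mm_hadd_epi32 chain the old code used. A minimal standalone check of
the helper (compile with, e.g., gcc -mavx hsum_check.c; only SSE2
intrinsics are actually involved):

    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Copied verbatim from the hunk above. */
    static inline int32_t hsum_int_4(__m128i v) {
        __m128i hi64 = _mm_unpackhi_epi64(v, v);    /* [v2 v3 v2 v3] */
        __m128i sum64 = _mm_add_epi32(hi64, v);     /* lane0 = v0+v2, lane1 = v1+v3 */
        __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); /* swap dwords 0 and 1 */
        __m128i sum32 = _mm_add_epi32(sum64, hi32); /* lane0 = v0+v1+v2+v3 */
        return _mm_cvtsi128_si32(sum32);
    }

    int main(void) {
        __m128i v = _mm_setr_epi32(1, -2, 30, 400);
        printf("%d\n", (int)hsum_int_4(v));         /* prints 429 */
        return 0;
    }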
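2) The scalar accumulator summs (one integer horizontal sum per block,
via two _mm_hadd_epi32 plus an extract) is replaced by the 4-lane float
accumulator acc_m, updated with a single _mm_fmadd_ps per block. The
horizontal reduction then happens once, after the loop, as the two-step
movehl/movehdup sequence in the last hunk. An illustrative sketch of
that reduction, wrapped in a hypothetical helper name hsum_float_4
(compile with, e.g., gcc -msse3):

    #include <immintrin.h>
    #include <stdio.h>

    /* Same two-step reduction as the patched epilogue: fold the high
     * pair onto the low pair, then fold the remaining two lanes so
     * lane 0 holds the full sum. */
    static inline float hsum_float_4(__m128 v) {
        v = _mm_add_ps(v, _mm_movehl_ps(v, v));   /* lane0 = v0+v2, lane1 = v1+v3 */
        v = _mm_add_ss(v, _mm_movehdup_ps(v));    /* lane0 = v0+v1+v2+v3 */
        return _mm_cvtss_f32(v);
    }

    int main(void) {
        __m128 v = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f);
        printf("%g\n", hsum_float_4(v));          /* prints 10 */
        return 0;
    }

This moves the per-block horizontal sum out of the hot loop entirely;
together with the reordered q8 loads in the inner loop (each 256-bit
load now feeds its multiply-add chain immediately), that is presumably
where the reported ~5% perplexity-time drop comes from.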