A slightly faster Q4_K AVX2 dot product

For perplexity, where we are less memory bound, time per
pass drops by ~5%. There is barely a measurable difference
for single-token prediction.
Iwan Kawrakow 2023-06-02 17:28:38 +03:00
parent 9a9c5a0c80
commit 894210a351

@@ -1499,6 +1499,16 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 }
 
+#ifdef __AVX__
+static inline int32_t hsum_int_4(__m128i v) {
+    __m128i hi64 = _mm_unpackhi_epi64(v, v);
+    __m128i sum64 = _mm_add_epi32(hi64, v);
+    __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2));
+    __m128i sum32 = _mm_add_epi32(sum64, hi32);
+    return _mm_cvtsi128_si32(sum32);
+}
+#endif
+
 void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
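For reference, here is a standalone check (my own test harness, not part of the commit) of what the new helper computes: hsum_int_4 horizontally sums the four 32-bit lanes of a __m128i using only SSE2 shuffles and adds.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

static inline int32_t hsum_int_4(__m128i v) {
    __m128i hi64 = _mm_unpackhi_epi64(v, v);                              // [v2, v3, v2, v3]
    __m128i sum64 = _mm_add_epi32(hi64, v);                               // [v0+v2, v1+v3, ...]
    __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2));   // swap the two low dwords
    __m128i sum32 = _mm_add_epi32(sum64, hi32);                           // lane 0: v0+v1+v2+v3
    return _mm_cvtsi128_si32(sum32);
}

int main(void) {
    __m128i v = _mm_setr_epi32(1, -2, 30, 400);
    printf("%d\n", (int)hsum_int_4(v));   // prints 429
    return 0;
}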
@@ -1596,11 +1606,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m128i mzero = _mm_setzero_si128();
 
     __m256 acc = _mm256_setzero_ps();
-
-    float summs = 0.f;
+    __m128 acc_m = _mm_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
@@ -1622,8 +1630,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
         const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
         const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
-        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
-        summs += dmin * _mm_extract_epi32(hsum, 0);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
         const __m256i scales = _mm256_set_m128i(sc128, sc128);
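A standalone sketch (not the kernel itself) of the accumulation change in this and the previous hunk: the old code collapsed each block's four 32-bit products to a scalar with two hadds and an extract before multiplying by dmin; the new code does one _mm_fmadd_ps per block into a four-lane float accumulator and reduces it once after the loop. The dmin and prod values below are made up; build with something like -mavx2 -mfma.

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    __m128i prods[3];
    prods[0] = _mm_setr_epi32(  10, -3,  7, 2);
    prods[1] = _mm_setr_epi32(-100, 50, 11, 8);
    prods[2] = _mm_setr_epi32(   1,  2,  3, 4);
    const float dmin[3] = {0.5f, -1.25f, 2.0f};

    // Old pattern: horizontal reduction to a scalar inside the block loop.
    const __m128i mzero = _mm_setzero_si128();
    float summs = 0.f;
    for (int i = 0; i < 3; ++i) {
        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prods[i], mzero), mzero);
        summs += dmin[i] * _mm_extract_epi32(hsum, 0);
    }

    // New pattern: one FMA per block; reduce the four partial sums once at the end.
    __m128 acc_m = _mm_setzero_ps();
    for (int i = 0; i < 3; ++i) {
        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin[i]), _mm_cvtepi32_ps(prods[i]), acc_m);
    }
    float lanes[4];
    _mm_storeu_ps(lanes, acc_m);

    printf("old = %f  new = %f\n", summs, lanes[0] + lanes[1] + lanes[2] + lanes[3]);
    return 0;
}

Both loops print 66.75 here; the point of the change is that the hadd/extract round trip disappears from the per-block path.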
@@ -1638,16 +1645,16 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
 
             const __m256i q4l = _mm256_and_si256(q4bits, m4);
             const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
 
             const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-
             __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
-            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
-
             p16l = _mm256_madd_epi16(scale_l, p16l);
-            p16h = _mm256_madd_epi16(scale_h, p16h);
-            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16l, p16h));
+            sumi = _mm256_add_epi32(sumi, p16l);
+
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            sumi = _mm256_add_epi32(sumi, p16h);
 
         }
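The inner-loop change above is a pure reordering rather than different arithmetic: instead of interleaving the low- and high-nibble halves (both q8 loads, then both maddubs, then both madd), the new code finishes the low-nibble half completely and then does the high-nibble half, so each 32-byte q8 load is consumed immediately. My reading, not stated in the commit, is that this shortens live ranges and hands the out-of-order core dependent work sooner.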
@@ -1656,7 +1663,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
 
 #else
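The movehl/movehdup pair at the end replaces the old scalar summs bookkeeping: it folds the four lanes of acc_m down into lane 0 so a single _mm_cvtss_f32 yields the mins correction. A standalone check of just that reduction (not part of the commit):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    __m128 acc_m = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f);
    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));   // [1+3, 2+4, ...]
    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));        // lane 0: (1+3)+(2+4)
    printf("%f\n", _mm_cvtss_f32(acc_m));                     // prints 10.000000
    return 0;
}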