A slightly daster Q4_K AVX2 dot product
For perplexity, where we are less memory bound, time per pass drops by ~5%. Barely measurable difference for single token prediction.
This commit is contained in:
parent
9a9c5a0c80
commit
894210a351
1 changed files with 22 additions and 12 deletions
34
k_quants.c
34
k_quants.c
|
@ -1499,6 +1499,16 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
}
|
||||
|
||||
#ifdef __AVX__
|
||||
static inline int32_t hsum_int_4(__m128i v) {
|
||||
__m128i hi64 = _mm_unpackhi_epi64(v, v);
|
||||
__m128i sum64 = _mm_add_epi32(hi64, v);
|
||||
__m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2));
|
||||
__m128i sum32 = _mm_add_epi32(sum64, hi32);
|
||||
return _mm_cvtsi128_si32(sum32);
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
assert(n % QK_K == 0);
|
||||
|
||||
|
@ -1596,11 +1606,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
#elif defined __AVX2__
|
||||
|
||||
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||
const __m128i mzero = _mm_setzero_si128();
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
float summs = 0.f;
|
||||
__m128 acc_m = _mm_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
|
@ -1622,8 +1630,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
||||
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
||||
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
||||
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
|
||||
summs += dmin * _mm_extract_epi32(hsum, 0);
|
||||
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
||||
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
||||
|
@ -1638,16 +1645,16 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
||||
const __m256i q4l = _mm256_and_si256(q4bits, m4);
|
||||
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
||||
|
||||
const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
|
||||
__m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
|
||||
__m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
|
||||
|
||||
p16l = _mm256_madd_epi16(scale_l, p16l);
|
||||
p16h = _mm256_madd_epi16(scale_h, p16h);
|
||||
sumi = _mm256_add_epi32(sumi, p16l);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16l, p16h));
|
||||
const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
__m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
|
||||
p16h = _mm256_madd_epi16(scale_h, p16h);
|
||||
sumi = _mm256_add_epi32(sumi, p16h);
|
||||
|
||||
}
|
||||
|
||||
|
@ -1656,7 +1663,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
||||
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
||||
|
||||
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
||||
|
||||
#else
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue