add AVX to ggml_vec_dot_q4_K_q8_K()
This commit is contained in:
parent
2024d40a00
commit
56df218f56
1 changed file with 55 additions and 1 deletion
56
k_quants.c
56
k_quants.c
|
@@ -2403,7 +2403,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
||||
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
|
||||
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
||||
// and 2 if the high bit was set)
|
||||
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
|
@@ -2924,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
*s = hsum_float_8(acc) - summs;
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m4 = _mm_set1_epi8(0xF);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
float summs = 0;
|
||||
|
||||
uint16_t aux16[2];
|
||||
const uint8_t * scales = (const uint8_t *)aux16;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
|
||||
const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
|
||||
const __m256 vd = _mm256_set1_ps(d);
|
||||
|
||||
const uint16_t * a = (const uint16_t *)x[i].scales;
|
||||
aux16[0] = a[0] & 0x0f0f;
|
||||
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
||||
|
||||
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
||||
|
||||
const uint8_t * restrict q4 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
|
||||
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
|
||||
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
|
||||
const __m128i q4ll = _mm_and_si128(q4bits_0, m4);
|
||||
const __m128i q4lh = _mm_and_si128(q4bits_1, m4);
|
||||
const __m128i q4hl = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
|
||||
const __m128i q4hh = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
|
||||
|
||||
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
const __m128i p16ll = _mm_maddubs_epi16(q4ll, _mm256_extractf128_si256(q8l, 0));
|
||||
const __m128i p16lh = _mm_maddubs_epi16(q4lh, _mm256_extractf128_si256(q8l, 1));
|
||||
const __m128i p16hl = _mm_maddubs_epi16(q4hl, _mm256_extractf128_si256(q8h, 0));
|
||||
const __m128i p16hh = _mm_maddubs_epi16(q4hh, _mm256_extractf128_si256(q8h, 1));
|
||||
|
||||
const __m128i p32ll = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16ll);
|
||||
const __m128i p32lh = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16lh);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32lh, p32ll))), acc);
|
||||
|
||||
const __m128i p32hl = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hl);
|
||||
const __m128i p32hh = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hh);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32hh, p32hl))), acc);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) - summs;
|
||||
|
||||
#else
|
||||
|
||||
uint8_t aux8[QK_K];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue