add AVX to ggml_vec_dot_q4_K_q8_K()

This commit is contained in:
katsu560 2023-07-23 17:26:41 +09:00
parent 2024d40a00
commit 56df218f56

View file

@@ -2403,7 +2403,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
-    // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+    // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
// and 2 if the high bit was set)
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
@@ -2924,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
*s = hsum_float_8(acc) - summs;
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
__m256 acc = _mm256_setzero_ps();
float summs = 0;
uint16_t aux16[2];
const uint8_t * scales = (const uint8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
const __m256 vd = _mm256_set1_ps(d);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
const __m128i q4ll = _mm_and_si128(q4bits_0, m4);
const __m128i q4lh = _mm_and_si128(q4bits_1, m4);
const __m128i q4hl = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
const __m128i q4hh = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p16ll = _mm_maddubs_epi16(q4ll, _mm256_extractf128_si256(q8l, 0));
const __m128i p16lh = _mm_maddubs_epi16(q4lh, _mm256_extractf128_si256(q8l, 1));
const __m128i p16hl = _mm_maddubs_epi16(q4hl, _mm256_extractf128_si256(q8h, 0));
const __m128i p16hh = _mm_maddubs_epi16(q4hh, _mm256_extractf128_si256(q8h, 1));
const __m128i p32ll = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16ll);
const __m128i p32lh = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16lh);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32lh, p32ll))), acc);
const __m128i p32hl = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hl);
const __m128i p32hh = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hh);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32hh, p32hl))), acc);
}
*s = hsum_float_8(acc) - summs;
#else
uint8_t aux8[QK_K];