k_quants: AVX2 implementation for new 64-weight Q5_K

Iwan Kawrakow 2023-06-24 18:17:35 +03:00
parent ccf4901334
commit 2da3a59708


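The removed lines below implement the previous 64-weight Q5_K format, which carried two fp16 scale/min pairs per block (x[i].d[0..3]) and folded the mins in via y[i].bsums (the summs accumulator). The added lines target the new format: one fp16 super-block scale x[i].d plus four signed 8-bit scales, one per 16 weights, with the high bit acting as a -16/+0 offset instead of a min. As a sketch, these are the two layouts the diff implies (field order is an assumption and block_q5_K_old is a reconstruction, not code from this commit; ggml_fp16_t is ggml's half-precision storage type):

    typedef struct {
        ggml_fp16_t d[4];   // fp16 scale/min pair per 32 weights: {s0, m0, s1, m1}
        uint8_t qh[8];      // high bits, one per weight
        uint8_t qs[32];     // low 4 bits, two weights per byte
    } block_q5_K_old;       // hypothetical: what the removed code reads

    typedef struct {
        ggml_fp16_t d;      // single fp16 super-block scale
        int8_t  scales[4];  // signed 8-bit scale per 16 weights
        uint8_t qh[8];      // high bits, one per weight
        uint8_t qs[32];     // low 4 bits, two weights per byte
    } block_q5_K_new;       // hypothetical: what the added code reads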
@@ -2821,55 +2821,51 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     *s = sumf;
 
-#elif defined z__AVX2__
+#elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);
     const __m256i mone = _mm256_set1_epi8(1);
 
     __m256 acc = _mm256_setzero_ps();
 
-    float summs = 0.f;
-
     for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const float d1 = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
-        const float d2 = y[i].d * ggml_fp16_to_fp32(x[i].d[2]);
-
-        summs -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
 
-        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)(q5+ 0));
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
-        const uint64_t * restrict sc = (const uint64_t *)x[i].qh;
-        const __m128i hbits_0 = _mm_set_epi64x(sc[0] >> 1, sc[0] >> 0);
-        const __m128i hbits_1 = _mm_set_epi64x(sc[0] >> 3, sc[0] >> 2);
-        const __m128i hbits_2 = _mm_set_epi64x(sc[0] >> 5, sc[0] >> 4);
-        const __m128i hbits_3 = _mm_set_epi64x(sc[0] >> 7, sc[0] >> 6);
+        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
 
-        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_1, hbits_0), mone), 4);
-        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_3, hbits_2), mone), 4);
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
 
         const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
         const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
 
-        const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
-        const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
-
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
         const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
 
-        const __m256i p16_0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_0, q8_0));
-        const __m256i p16_1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_1, q8_1));
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
 
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(d1), _mm256_cvtepi32_ps(p16_0), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(d2), _mm256_cvtepi32_ps(p16_1), acc);
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    *s = hsum_float_8(acc);
 
 #else
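
For reference, a scalar sketch of what the new kernel computes. Weight j reconstructs to scales[j/16] * (lo4 - (hbit ? 0 : 16)), which is why the kernel can form dot = (p16_0 + p16_1) - (s16_0 + s16_1): the p16 terms carry the low-nibble products and the s16 terms subtract 16*q8 wherever the high bit is 0. The qh bit order is inferred from the haux256 shuffle above (bit j/8 of byte j%8 for weight j, bit 4 + j/8 of the same byte for weight j + 32); the _ref names are hypothetical and fp16 fields are replaced by float so the sketch stands alone:

    #include <stdint.h>

    typedef struct {
        float   d;          // super-block scale (ggml_fp16_t in the real struct)
        int8_t  scales[4];  // one signed scale per 16 weights
        uint8_t qh[8];      // high bits, one per weight
        uint8_t qs[32];     // low 4 bits, two weights per byte
    } block_q5_K_ref;

    typedef struct {
        float  d;           // activation scale
        int8_t qs[64];      // quantized activations
    } block_q8_K_ref;

    static float vec_dot_q5_K_q8_K_ref(int nb, const block_q5_K_ref * x, const block_q8_K_ref * y) {
        float sumf = 0.f;
        for (int i = 0; i < nb; ++i) {
            int32_t sumi = 0;
            for (int j = 0; j < 32; ++j) {
                const int h1 = (x[i].qh[j % 8] >> (j / 8)) & 1;      // high bit of weight j
                const int h2 = (x[i].qh[j % 8] >> (4 + j / 8)) & 1;  // high bit of weight j + 32
                const int q1 = (x[i].qs[j] & 0xF) - (h1 ? 0 : 16);   // values in [-16, 15]
                const int q2 = (x[i].qs[j] >>  4) - (h2 ? 0 : 16);
                sumi += x[i].scales[j/16]     * q1 * y[i].qs[j];
                sumi += x[i].scales[2 + j/16] * q2 * y[i].qs[j + 32];
            }
            sumf += x[i].d * y[i].d * (float)sumi;
        }
        return sumf;
    }

Note the trick this split enables: both q5l (0..15) and q5h (0 or 16) stay non-negative bytes, so _mm256_maddubs_epi16 can multiply them against the signed q8 values without overflowing int16, and the signed part of each weight is recovered only at the end by the p16 - s16 subtraction.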