add AVX to ggml_vec_dot_q3_K_q8_K()

This commit is contained in:
katsu560 2023-07-23 16:58:27 +09:00
parent 6afbb11c01
commit 2024d40a00

View file

@ -2351,6 +2351,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
*s = hsum_float_8(acc); *s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m3 = _mm_set1_epi8(3);
const __m128i m1 = _mm_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
uint64_t aux64;
uint16_t aux16[2];
const int8_t * aux8 = (const int8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
aux16[1] = (a >> 4) & 0x0f0f;
const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
memcpy(&aux64, x[i].hmask, 8);
__m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
__m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
__m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
__m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
// load low 2 bits
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
// prepare low and high bits
const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
// load Q8 quants
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
// and 2 if the high bit was set)
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
// multiply with scales
p16_0 = _mm_madd_epi16(scale_0, p16_0);
p16_1 = _mm_madd_epi16(scale_1, p16_1);
p16_2 = _mm_madd_epi16(scale_2, p16_2);
p16_3 = _mm_madd_epi16(scale_3, p16_3);
p16_0 = _mm_add_epi32(p16_0, p16_2);
p16_1 = _mm_add_epi32(p16_1, p16_3);
__m256i p16 = _mm256_set_m128i(p16_1, p16_0);
// multiply with block scale and accumulate
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
}
*s = hsum_float_8(acc);
#else #else
int8_t aux8[QK_K]; int8_t aux8[QK_K];