add AVX to ggml_vec_dot_q3_K_q8_K()
This commit is contained in:
parent
6afbb11c01
commit
2024d40a00
1 changed files with 87 additions and 0 deletions
87
k_quants.c
87
k_quants.c
|
@ -2351,6 +2351,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
*s = hsum_float_8(acc);
|
*s = hsum_float_8(acc);
|
||||||
|
|
||||||
|
#elif defined __AVX__
|
||||||
|
|
||||||
|
const __m128i m3 = _mm_set1_epi8(3);
|
||||||
|
const __m128i m1 = _mm_set1_epi8(1);
|
||||||
|
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
uint64_t aux64;
|
||||||
|
|
||||||
|
uint16_t aux16[2];
|
||||||
|
const int8_t * aux8 = (const int8_t *)aux16;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||||
|
|
||||||
|
const uint8_t * restrict q3 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const uint16_t a = *(const uint16_t *)x[i].scales;
|
||||||
|
aux16[0] = a & 0x0f0f;
|
||||||
|
aux16[1] = (a >> 4) & 0x0f0f;
|
||||||
|
|
||||||
|
const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
|
||||||
|
const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
|
||||||
|
const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
|
||||||
|
const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
|
||||||
|
|
||||||
|
memcpy(&aux64, x[i].hmask, 8);
|
||||||
|
|
||||||
|
__m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
||||||
|
__m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
|
||||||
|
__m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
|
||||||
|
__m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
|
||||||
|
q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
|
||||||
|
q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
|
||||||
|
q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
|
||||||
|
q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
|
||||||
|
|
||||||
|
// load low 2 bits
|
||||||
|
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
||||||
|
|
||||||
|
// prepare low and high bits
|
||||||
|
const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
|
||||||
|
const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
|
||||||
|
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
|
||||||
|
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
|
||||||
|
|
||||||
|
// load Q8 quants
|
||||||
|
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||||
|
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||||
|
|
||||||
|
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
||||||
|
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
||||||
|
// and 2 if the high bit was set)
|
||||||
|
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
|
||||||
|
const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
|
||||||
|
const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
|
||||||
|
const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
|
||||||
|
|
||||||
|
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
|
||||||
|
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
|
||||||
|
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
|
||||||
|
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
|
||||||
|
|
||||||
|
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
||||||
|
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
||||||
|
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
||||||
|
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
||||||
|
|
||||||
|
// multiply with scales
|
||||||
|
p16_0 = _mm_madd_epi16(scale_0, p16_0);
|
||||||
|
p16_1 = _mm_madd_epi16(scale_1, p16_1);
|
||||||
|
p16_2 = _mm_madd_epi16(scale_2, p16_2);
|
||||||
|
p16_3 = _mm_madd_epi16(scale_3, p16_3);
|
||||||
|
|
||||||
|
p16_0 = _mm_add_epi32(p16_0, p16_2);
|
||||||
|
p16_1 = _mm_add_epi32(p16_1, p16_3);
|
||||||
|
__m256i p16 = _mm256_set_m128i(p16_1, p16_0);
|
||||||
|
|
||||||
|
// multiply with block scale and accumulate
|
||||||
|
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_8(acc);
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
int8_t aux8[QK_K];
|
int8_t aux8[QK_K];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue