add AVX to ggml_vec_dot_q2_K_q8_K()
This commit is contained in:
parent
b9b7d94fc1
commit
6afbb11c01
1 changed files with 56 additions and 0 deletions
56
k_quants.c
56
k_quants.c
|
@ -1666,6 +1666,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m3 = _mm_set1_epi8(3);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
uint32_t ud, um;
|
||||
const uint8_t * restrict db = (const uint8_t *)&ud;
|
||||
const uint8_t * restrict mb = (const uint8_t *)&um;
|
||||
|
||||
float summs = 0;
|
||||
|
||||
// TODO: optimize this
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
||||
|
||||
const uint8_t * restrict q2 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
||||
ud = (sc[0] >> 0) & 0x0f0f0f0f;
|
||||
um = (sc[0] >> 4) & 0x0f0f0f0f;
|
||||
|
||||
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
|
||||
summs += dmin * smin;
|
||||
|
||||
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
||||
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
|
||||
const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
||||
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
||||
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
|
||||
const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
|
||||
const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
|
||||
const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
|
||||
const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
|
||||
#else
|
||||
|
||||
float sumf = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue