refactor AVX code in ggml_vec_dot_q6_K_q8_K()

This commit is contained in:
katsu560 2023-07-23 19:28:48 +09:00
parent 4775602c9f
commit 5f98fc2f90

View file

@ -4142,41 +4142,41 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
const __m128i q4h_0l = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
const __m128i q4h_0h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
const __m128i q4h_1l = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
const __m128i q4h_1h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
const __m128i q4_0l = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0l);
const __m128i q4_0h = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_0h);
const __m128i q4_1l = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_1l);
const __m128i q4_1h = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_1h);
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
__m128i q8s_0l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
__m128i q8s_0h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
__m128i q8s_1l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
__m128i q8s_1h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
__m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
__m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
__m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
__m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
__m128i p16_0l = _mm_maddubs_epi16(q4_0l, _mm256_extractf128_si256(q8_0, 0));
__m128i p16_0h = _mm_maddubs_epi16(q4_0h, _mm256_extractf128_si256(q8_0, 1));
__m128i p16_1l = _mm_maddubs_epi16(q4_1l, _mm256_extractf128_si256(q8_1, 0));
__m128i p16_1h = _mm_maddubs_epi16(q4_1h, _mm256_extractf128_si256(q8_1, 1));
__m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
__m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
__m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
__m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
p16_0l = _mm_sub_epi16(p16_0l, q8s_0l);
p16_0h = _mm_sub_epi16(p16_0h, q8s_0h);
p16_1l = _mm_sub_epi16(p16_1l, q8s_1l);
p16_1h = _mm_sub_epi16(p16_1h, q8s_1h);
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
p16_0l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0l);
p16_0h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_0h);
p16_1l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_1l);
p16_1h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_1h);
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0l, p16_1l));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_0h, p16_1h));
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
}