diff --git a/k_quants.c b/k_quants.c
index 0214d724d..0fa586aaf 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -4112,6 +4112,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0l = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_0h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_1l = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_1h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0l = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0l);
+        const __m128i q4_0h = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_0h);
+        const __m128i q4_1l = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_1l);
+        const __m128i q4_1h = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_1h);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_0h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_1l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_1h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0l = _mm_maddubs_epi16(q4_0l, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_0h = _mm_maddubs_epi16(q4_0h, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_1l = _mm_maddubs_epi16(q4_1l, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_1h = _mm_maddubs_epi16(q4_1h, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0l = _mm_sub_epi16(p16_0l, q8s_0l);
+        p16_0h = _mm_sub_epi16(p16_0h, q8s_0h);
+        p16_1l = _mm_sub_epi16(p16_1l, q8s_1l);
+        p16_1h = _mm_sub_epi16(p16_1h, q8s_1h);
+
+        p16_0l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0l);
+        p16_0h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_0h);
+        p16_1l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_1l);
+        p16_1h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_1h);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0l, p16_1l));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_0h, p16_1h));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t  aux8[QK_K];
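
Note on the m32s handling (not part of the patch): _mm_maddubs_epi16 treats its first operand as unsigned bytes, so the 6-bit quants (0..63) cannot be shifted into the signed range -32..31 before the multiply. The patch therefore computes q6*q8 and 32*q8 separately and subtracts, relying on the identity (q6 - 32) * q8 == q6*q8 - 32*q8. The scalar sketch below only illustrates that identity; the variable names are illustrative, not taken from the patch.

/* scalar_sketch.c -- illustrates the (q6 - 32) * q8 decomposition used above */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* hypothetical sample values: q6 is an unsigned 6-bit quant, q8 a signed 8-bit value */
    const uint8_t q6 = 45;   /* 0..63 */
    const int8_t  q8 = -17;

    const int direct   = (q6 - 32) * q8;     /* the product we actually want          */
    const int two_step = q6 * q8 - 32 * q8;  /* what the maddubs + sub sequence yields */

    printf("%d == %d\n", direct, two_step);  /* identical by distributivity */
    return 0;
}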