diff --git a/k_quants.c b/k_quants.c index 6498e5f9d..6b398abcd 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2403,7 +2403,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, // and 2 if the high bit was set) const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); @@ -2924,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) - summs; +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; + const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4ll = _mm_and_si128(q4bits_0, m4); + const __m128i q4lh = _mm_and_si128(q4bits_1, m4); + const __m128i q4hl = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4hh = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16ll = _mm_maddubs_epi16(q4ll, _mm256_extractf128_si256(q8l, 0)); + const __m128i p16lh = _mm_maddubs_epi16(q4lh, _mm256_extractf128_si256(q8l, 1)); + const __m128i p16hl = _mm_maddubs_epi16(q4hl, _mm256_extractf128_si256(q8h, 0)); + const __m128i p16hh = _mm_maddubs_epi16(q4hh, _mm256_extractf128_si256(q8h, 1)); + + const __m128i p32ll = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16ll); + const __m128i p32lh = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16lh); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32lh, p32ll))), acc); + + const __m128i p32hl = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hl); + const __m128i p32hh = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hh); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32hh, p32hl))), acc); + + } + + *s = hsum_float_8(acc) - summs; + #else uint8_t aux8[QK_K];