diff --git a/k_quants.c b/k_quants.c
index 0fa586aaf..c75361bee 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -4142,41 +4142,41 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
         const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
 
-        const __m128i q4h_0l = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
-        const __m128i q4h_0h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
-        const __m128i q4h_1l = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
-        const __m128i q4h_1h = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
 
-        const __m128i q4_0l = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0l);
-        const __m128i q4_0h = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_0h);
-        const __m128i q4_1l = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_1l);
-        const __m128i q4_1h = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_1h);
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
 
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
         const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
 
-        __m128i q8s_0l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
-        __m128i q8s_0h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
-        __m128i q8s_1l = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
-        __m128i q8s_1h = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
 
-        __m128i p16_0l = _mm_maddubs_epi16(q4_0l, _mm256_extractf128_si256(q8_0, 0));
-        __m128i p16_0h = _mm_maddubs_epi16(q4_0h, _mm256_extractf128_si256(q8_0, 1));
-        __m128i p16_1l = _mm_maddubs_epi16(q4_1l, _mm256_extractf128_si256(q8_1, 0));
-        __m128i p16_1h = _mm_maddubs_epi16(q4_1h, _mm256_extractf128_si256(q8_1, 1));
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
 
-        p16_0l = _mm_sub_epi16(p16_0l, q8s_0l);
-        p16_0h = _mm_sub_epi16(p16_0h, q8s_0h);
-        p16_1l = _mm_sub_epi16(p16_1l, q8s_1l);
-        p16_1h = _mm_sub_epi16(p16_1h, q8s_1h);
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
 
-        p16_0l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0l);
-        p16_0h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_0h);
-        p16_1l = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_1l);
-        p16_1h = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_1h);
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
 
-        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0l, p16_1l));
-        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_0h, p16_1h));
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
 
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
     }
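For reference, here is a scalar sketch (not part of the patch) of what each renamed p16_0..p16_3 accumulator computes on its 16-value lane: the kernel rebuilds each 6-bit quant from the low 4 bits in ql and 2 high bits from qh, multiplies by the signed q8 values with _mm_maddubs_epi16, subtracts the 32*q8 bias separately (the m32s term) since maddubs only takes unsigned-by-signed operands, and then applies the per-group int8 scale with _mm_madd_epi16. The names q6k_lane_dot, ql, qh_bits and scale below are illustrative stand-ins, not identifiers from k_quants.c, and qh_bits is assumed to already hold the extracted 2-bit high parts.

#include <stdint.h>

// Illustrative scalar equivalent of one p16_x lane (16 quant values):
// reconstruct each 6-bit value, remove the +32 bias, multiply by q8,
// and weight the partial sum by the group's int8 scale.
static int32_t q6k_lane_dot(const uint8_t ql[16], const uint8_t qh_bits[16],
                            const int8_t q8[16], int8_t scale) {
    int32_t sum = 0;
    for (int k = 0; k < 16; ++k) {
        const int32_t q6 = (int32_t)((ql[k] & 0x0F) | ((qh_bits[k] & 0x03) << 4)) - 32;
        sum += q6 * (int32_t)q8[k];
    }
    return (int32_t)scale * sum;
}

The overall block result would then be the float scale d times the sum of these per-lane values, which is what the final _mm256_mul_ps/_mm256_add_ps on acc performs in the SIMD path.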