add AVX to ggml_vec_dot_q5_K_q8_K()

This commit is contained in:
katsu560 2023-07-23 18:17:12 +09:00
parent 56df218f56
commit dfd2e0e5d6

View file

@ -2953,26 +2953,26 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
const __m128i q4ll = _mm_and_si128(q4bits_0, m4);
const __m128i q4lh = _mm_and_si128(q4bits_1, m4);
const __m128i q4hl = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
const __m128i q4hh = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p16ll = _mm_maddubs_epi16(q4ll, _mm256_extractf128_si256(q8l, 0));
const __m128i p16lh = _mm_maddubs_epi16(q4lh, _mm256_extractf128_si256(q8l, 1));
const __m128i p16hl = _mm_maddubs_epi16(q4hl, _mm256_extractf128_si256(q8h, 0));
const __m128i p16hh = _mm_maddubs_epi16(q4hh, _mm256_extractf128_si256(q8h, 1));
const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
const __m128i p32ll = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16ll);
const __m128i p32lh = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16lh);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32lh, p32ll))), acc);
const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
const __m128i p32hl = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hl);
const __m128i p32hh = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16hh);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32hh, p32hl))), acc);
const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
}
@ -3492,6 +3492,63 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
*s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i mone = _mm_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
int64_t aux64;
memcpy(&aux64, x[i].qh, 8);
const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
}
*s = hsum_float_8(acc);
#else