~7% faster Q5_1 AVX2 code
X is always positive (range 0..31) in ggml_vec_dot_q5_1_q8_1(), so there's no need for two _mm256_sign_epi8().
This commit is contained in:
parent
2a5ee023ad
commit
2d61ed7fc6
1 changed files with 12 additions and 8 deletions
20
ggml.c
20
ggml.c
|
@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
|||
return _mm256_cvtepi32_ps(summed_pairs);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as float vector
|
||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||
#if __AVXVNNI__
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
||||
|
@ -560,6 +555,15 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as float vector
|
||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||
return mul_sum_us8_pairs_float(ax, sy);
|
||||
}
|
||||
|
||||
static inline __m128i packNibbles( __m256i bytes )
|
||||
{
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
|
@ -2906,7 +2910,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx, by);
|
||||
|
||||
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
||||
}
|
||||
|
@ -2940,7 +2944,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx, by);
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue