diff --git a/ggml.c b/ggml.c index f3e4b19c4..be428e28e 100644 --- a/ggml.c +++ b/ggml.c @@ -505,6 +505,11 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { // multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else // Get absolute values of x vectors const __m256i ax = _mm256_sign_epi8(x, x); // Sign the values of the y vectors @@ -512,6 +517,7 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); return sum_i16_pairs_float(dot); +#endif } static inline __m128i packNibbles( __m256i bytes )