diff --git a/ggml.c b/ggml.c index 7f1c4409d..05710559a 100644 --- a/ggml.c +++ b/ggml.c @@ -2021,17 +2021,17 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest bx = _mm_sub_epi8( bx, off ); by = _mm_sub_epi8( by, off ); - // Sign-extend first 8 signed bytes into int16_t - __m128i x16 = _mm_cvtepi8_epi16( bx ); - __m128i y16 = _mm_cvtepi8_epi16( by ); - // Compute products of int16_t integers, add pairwise - i32[j] = _mm_madd_epi16( x16, y16 ); + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(bx, bx); - // Sign-extend last 8 signed bytes into int16_t vectors - x16 = _mm_cvtepi8_epi16( _mm_srli_si128( bx, 8 ) ); - y16 = _mm_cvtepi8_epi16( _mm_srli_si128( by, 8 ) ); - // Accumulate products of int16_t integers - i32[j] = _mm_add_epi32( i32[j], _mm_madd_epi16( x16, y16 ) ); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + + const __m128i ones = _mm_set1_epi16(1); + i32[j] = _mm_madd_epi16(ones, dot); } // Convert int32_t to float