diff --git a/ggml.c b/ggml.c
index d9ae20140..7f1c4409d 100644
--- a/ggml.c
+++ b/ggml.c
@@ -462,6 +462,23 @@ static inline __m128i packNibbles( __m256i bytes )
     return _mm_packus_epi16( r0, r1 );
 }
 #elif __AVX__
+static inline __m128i bytesFromNibbles( const uint8_t* rsi )
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -1983,6 +2000,52 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+            by = _mm_sub_epi8( by, off );
+
+            // Sign-extend first 8 signed bytes into int16_t
+            __m128i x16 = _mm_cvtepi8_epi16( bx );
+            __m128i y16 = _mm_cvtepi8_epi16( by );
+            // Compute products of int16_t integers, add pairwise
+            i32[j] = _mm_madd_epi16( x16, y16 );
+
+            // Sign-extend last 8 signed bytes into int16_t vectors
+            x16 = _mm_cvtepi8_epi16( _mm_srli_si128( bx, 8 ) );
+            y16 = _mm_cvtepi8_epi16( _mm_srli_si128( by, 8 ) );
+            // Accumulate products of int16_t integers
+            i32[j] = _mm_add_epi32( i32[j], _mm_madd_epi16( x16, y16 ) );
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #elif defined(__wasm_simd128__)
     // wasm simd
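
For reference while reviewing, below is a scalar sketch of what the new `#elif defined(__AVX__)` branch computes. It is not part of the patch: the `vec_dot_q4_0_ref` name is hypothetical, and the `block_q4_0` layout (one `float` scale followed by `QK/2` bytes holding two 4-bit quants each, with `QK == 32` and adjacent elements packed as low/high nibbles) is assumed to match this revision of ggml.c.

```c
#include <stdint.h>

#define QK 32

// Assumed to match block_q4_0 in this revision of ggml.c:
// one fp32 scale, then QK 4-bit quants packed two per byte.
typedef struct {
    float   d;           // block scale
    uint8_t qs[QK / 2];  // byte k holds elements 2k (low nibble) and 2k+1 (high nibble)
} block_q4_0;

// Hypothetical scalar reference for the vectorized kernel: unpack each
// nibble, offset it from [ 0 .. 15 ] into [ -8 .. +7 ], multiply element-wise,
// and scale each block's integer sum by d_x * d_y.
static float vec_dot_q4_0_ref(int n, const block_q4_0 * x, const block_q4_0 * y) {
    const int nb = n / QK;
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int k = 0; k < QK / 2; ++k) {
            const int x0 = (x[i].qs[k] & 0x0F) - 8;  // low nibble
            const int x1 = (x[i].qs[k] >> 4)   - 8;  // high nibble
            const int y0 = (y[i].qs[k] & 0x0F) - 8;
            const int y1 = (y[i].qs[k] >> 4)   - 8;
            sumi += x0 * y0 + x1 * y1;
        }
        sumf += x[i].d * y[i].d * (float) sumi;
    }
    return sumf;
}
```

Each `bytesFromNibbles` call covers 8 of these packed bytes (16 nibbles), so the two iterations of the `j` loop handle one 32-element block per iteration of `i`, and the SIMD branch should agree with this reference up to float summation order. Since both operands use the same packing, only the nibble pairing matters for the result, not the element order after unpacking. Note that the new branch still relies on SSE4.1 intrinsics (`_mm_cvtepu8_epi16`, `_mm_cvtepi8_epi16`), which any AVX-capable CPU provides.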