ggml : add AVX ggml_vec_dot_q4_0()

Sergey Pershukov, 2023-03-30 09:50:40 +05:00
parent 79e14129e1
commit 93a3169284

ggml.c | 63 additions

@@ -462,6 +462,23 @@ static inline __m128i packNibbles( __m256i bytes )
    return _mm_packus_epi16( r0, r1 );
}
#elif __AVX__
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
{
    // Load 8 bytes from memory
    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );

    // Expand bytes into uint16_t values
    __m128i bytes = _mm_cvtepu8_epi16( tmp );

    // Unpack values into individual bytes
    const __m128i lowMask = _mm_set1_epi8( 0xF );
    __m128i high = _mm_andnot_si128( lowMask, bytes );
    __m128i low = _mm_and_si128( lowMask, bytes );
    high = _mm_slli_epi16( high, 4 );
    bytes = _mm_or_si128( low, high );
    return bytes;
}
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
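For reference (not part of the patch), here is a scalar sketch of what bytesFromNibbles computes; the helper name expand_nibbles_ref is illustrative. Byte k of the input contributes its low nibble to output byte 2k and its high nibble to output byte 2k+1, which is exactly the layout the andnot/shift/or sequence above produces:

#include <stdint.h>

// Hypothetical scalar equivalent of the SSE bytesFromNibbles above:
// input byte k yields output bytes 2k (low nibble) and 2k+1 (high nibble).
static void expand_nibbles_ref( const uint8_t * rsi, uint8_t out[16] )
{
    for (int k = 0; k < 8; ++k) {
        out[2*k + 0] = rsi[k] & 0x0F;  // low nibble
        out[2*k + 1] = rsi[k] >> 4;    // high nibble
    }
}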
@@ -1983,6 +2000,52 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
    sumf = _mm_cvtss_f32( res );
#elif defined(__AVX__)
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    // Main loop
    for (int i = 0; i < nb; ++i) {
        // Compute combined scale for the block
        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

        __m128i i32[2];
        for (int j = 0; j < 2; ++j) {
            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
            __m128i by = bytesFromNibbles( y[i].qs + 8*j );

            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
            const __m128i off = _mm_set1_epi8( 8 );
            bx = _mm_sub_epi8( bx, off );
            by = _mm_sub_epi8( by, off );

            // Sign-extend first 8 signed bytes into int16_t
            __m128i x16 = _mm_cvtepi8_epi16( bx );
            __m128i y16 = _mm_cvtepi8_epi16( by );
            // Compute products of int16_t integers, add pairwise
            i32[j] = _mm_madd_epi16( x16, y16 );

            // Sign-extend last 8 signed bytes into int16_t vectors
            x16 = _mm_cvtepi8_epi16( _mm_srli_si128( bx, 8 ) );
            y16 = _mm_cvtepi8_epi16( _mm_srli_si128( by, 8 ) );
            // Accumulate products of int16_t integers
            i32[j] = _mm_add_epi32( i32[j], _mm_madd_epi16( x16, y16 ) );
        }
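        // At this point each i32[j] holds four int32 lanes, and each lane is
        // the sum of four int8 products, so the two halves together cover all
        // 32 nibble products of the block.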
        // Convert int32_t to float
        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
        // Apply the scale, and accumulate
        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
    }

    // Return horizontal sum of the acc vector
    __m128 res = _mm256_extractf128_ps( acc, 1 );
    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
    sumf = _mm_cvtss_f32( res );
#elif defined(__wasm_simd128__)
    // wasm simd
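For context (again not part of the patch), a scalar sketch of the q4_0 dot product that the new AVX branch vectorizes. The block layout assumed here, one float scale d plus QK/2 = 16 bytes of packed nibbles per block of QK = 32 values, mirrors ggml's q4_0 blocks of this period, but the names block_q4_0_ref and vec_dot_q4_0_ref are illustrative:

#include <stdint.h>

#define QK 32

// Illustrative q4_0 block: one fp32 scale plus QK packed 4-bit quants.
typedef struct {
    float   d;
    uint8_t qs[QK / 2];
} block_q4_0_ref;

// Scalar sketch of the dot product the AVX branch computes.
static float vec_dot_q4_0_ref( int n, const block_q4_0_ref * x, const block_q4_0_ref * y )
{
    const int nb = n / QK;
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int k = 0; k < QK / 2; ++k) {
            // Unpack two nibbles per byte and offset from [0..15] to [-8..7]
            const int x0 = (x[i].qs[k] & 0x0F) - 8;
            const int x1 = (x[i].qs[k] >> 4)   - 8;
            const int y0 = (y[i].qs[k] & 0x0F) - 8;
            const int y1 = (y[i].qs[k] >> 4)   - 8;
            sumi += x0*y0 + x1*y1;
        }
        // Scale the integer sum by the combined per-block scale
        sumf += x[i].d * y[i].d * (float) sumi;
    }
    return sumf;
}

Each _mm_madd_epi16 in the AVX branch folds four of these int8 products into one int32 lane; the scalar sumi accumulates the same 32 products per block before the combined scale is applied.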