ggml : add AVX ggml_vec_dot_q4_0()
This commit is contained in:
parent
79e14129e1
commit
93a3169284
1 changed files with 63 additions and 0 deletions
63
ggml.c
63
ggml.c
|
@ -462,6 +462,23 @@ static inline __m128i packNibbles( __m256i bytes )
|
||||||
return _mm_packus_epi16( r0, r1 );
|
return _mm_packus_epi16( r0, r1 );
|
||||||
}
|
}
|
||||||
#elif __AVX__
|
#elif __AVX__
|
||||||
|
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
|
||||||
|
{
|
||||||
|
// Load 8 bytes from memory
|
||||||
|
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
|
||||||
|
|
||||||
|
// Expand bytes into uint16_t values
|
||||||
|
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
||||||
|
|
||||||
|
// Unpack values into individual bytes
|
||||||
|
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
||||||
|
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
||||||
|
__m128i low = _mm_and_si128( lowMask, bytes );
|
||||||
|
high = _mm_slli_epi16( high, 4 );
|
||||||
|
bytes = _mm_or_si128( low, high );
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||||
{
|
{
|
||||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||||
|
@ -1983,6 +2000,52 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
||||||
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
||||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||||
|
|
||||||
|
sumf = _mm_cvtss_f32( res );
|
||||||
|
#elif defined(__AVX__)
|
||||||
|
// Initialize accumulator with zeros
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
// Compute combined scale for the block
|
||||||
|
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
||||||
|
|
||||||
|
__m128i i32[2];
|
||||||
|
for (int j = 0; j < 2; ++j) {
|
||||||
|
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
||||||
|
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
|
||||||
|
__m128i by = bytesFromNibbles( y[i].qs + 8*j );
|
||||||
|
|
||||||
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||||
|
const __m128i off = _mm_set1_epi8( 8 );
|
||||||
|
bx = _mm_sub_epi8( bx, off );
|
||||||
|
by = _mm_sub_epi8( by, off );
|
||||||
|
|
||||||
|
// Sign-extend first 8 signed bytes into int16_t
|
||||||
|
__m128i x16 = _mm_cvtepi8_epi16( bx );
|
||||||
|
__m128i y16 = _mm_cvtepi8_epi16( by );
|
||||||
|
// Compute products of int16_t integers, add pairwise
|
||||||
|
i32[j] = _mm_madd_epi16( x16, y16 );
|
||||||
|
|
||||||
|
// Sign-extend last 8 signed bytes into int16_t vectors
|
||||||
|
x16 = _mm_cvtepi8_epi16( _mm_srli_si128( bx, 8 ) );
|
||||||
|
y16 = _mm_cvtepi8_epi16( _mm_srli_si128( by, 8 ) );
|
||||||
|
// Accumulate products of int16_t integers
|
||||||
|
i32[j] = _mm_add_epi32( i32[j], _mm_madd_epi16( x16, y16 ) );
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert int32_t to float
|
||||||
|
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
||||||
|
// Apply the scale, and accumulate
|
||||||
|
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return horizontal sum of the acc vector
|
||||||
|
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
||||||
|
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
||||||
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
||||||
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||||
|
|
||||||
sumf = _mm_cvtss_f32( res );
|
sumf = _mm_cvtss_f32( res );
|
||||||
#elif defined(__wasm_simd128__)
|
#elif defined(__wasm_simd128__)
|
||||||
// wasm simd
|
// wasm simd
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue