AVX2 implementation of ggml_vec_dot_q4_1_q8_0
This commit is contained in:
parent
ed24225917
commit
142c38a4f3
1 changed file with 56 additions and 0 deletions
56
ggml.c
56
ggml.c
|
@ -2518,6 +2518,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||||
|
#elif defined(__AVX2__)
|
||||||
|
// Initialize accumulator with zeros
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const float * d0 = &x[i].d;
|
||||||
|
const float * d1 = &y[i].d;
|
||||||
|
const float * m0 = &x[i].m;
|
||||||
|
|
||||||
|
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
||||||
|
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
||||||
|
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
||||||
|
|
||||||
|
// Compute combined scales
|
||||||
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
||||||
|
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
||||||
|
|
||||||
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
||||||
|
const __m256i bx = bytesFromNibbles( x[i].qs );
|
||||||
|
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
||||||
|
|
||||||
|
// Get absolute values of x vectors
|
||||||
|
const __m256i ax = _mm256_sign_epi8( bx, bx );
|
||||||
|
|
||||||
|
// Sign the values of the y vectors
|
||||||
|
const __m256i sy = _mm256_sign_epi8( by, bx );
|
||||||
|
|
||||||
|
// Perform multiplication and create 16-bit values
|
||||||
|
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
|
||||||
|
const __m256i ones = _mm256_set1_epi16( 1 );
|
||||||
|
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
|
||||||
|
|
||||||
|
// Convert the vector of 8 int32_t values to 8 floats
|
||||||
|
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
|
||||||
|
|
||||||
|
// Accumulate d0*d1*x*y
|
||||||
|
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
||||||
|
|
||||||
|
// Compute sum of y values
|
||||||
|
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
||||||
|
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
||||||
|
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
||||||
|
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
||||||
|
|
||||||
|
// Accumulate d1*m0*y
|
||||||
|
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return horizontal sum of the acc vector
|
||||||
|
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
||||||
|
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
||||||
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
||||||
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||||
|
|
||||||
|
sumf = _mm_cvtss_f32( res );
|
||||||
#else
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue