Make the intrinsics more readable

This commit is contained in:
Justine Tunney 2023-05-15 23:11:47 -07:00
parent 210187cf77
commit 80db9de173
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
75 changed files with 12444 additions and 21493 deletions

View file

@ -1784,24 +1784,40 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
//
// Main loop
for (int i = 0; i < nb; ++i) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m256i bx = bytes_from_nibbles_32(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc );
//
#define WORK(I) \
/* Compute combined scale for the block */ \
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[I].d ), _mm256_broadcast_ss( &y[I].d ) ); \
__m256i bx = bytes_from_nibbles_32(x[I].qs); \
/* Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */ \
const __m256i off = _mm256_set1_epi8( 8 ); \
bx = _mm256_sub_epi8( bx, off ); \
__m256i by = _mm256_loadu_si256((const __m256i *)y[I].qs); \
const __m256 q = mul_sum_i8_pairs_float(bx, by); \
/* Multiply q with scale and accumulate */ \
acc = _mm256_fmadd_ps( d, q, acc )
int i = 0;
for (; i + 12 < nb; i += 12) {
_mm_prefetch(x+i+12, 3);
_mm_prefetch(x+i+15, 3);
_mm_prefetch(x+i+18, 3);
_mm_prefetch(x+i+21, 3);
_mm_prefetch(y+i+12, 3);
_mm_prefetch(y+i+14, 3);
_mm_prefetch(y+i+16, 3);
_mm_prefetch(y+i+18, 3);
_mm_prefetch(y+i+20, 3);
_mm_prefetch(y+i+22, 3);
for (int j = 0; j < 12; ++j) {
WORK(i+j);
}
}
for (; i < nb; ++i) {
WORK(i);
}
#undef WORK
*s = hsum_float_8(acc);
#elif defined(__AVX__)