Split AVX512 process one block function out from inline
* Rename it to make it more clear that it's used for that dot product function
This commit is contained in:
parent
76af3f5a64
commit
05f2f48f70
1 changed files with 49 additions and 47 deletions
96
ggml.c
96
ggml.c
|
@ -397,6 +397,46 @@ static inline __m128i packNibbles( __m256i bytes )
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if __AVX512F__ && QK == 32
|
||||||
|
static inline __m512 dot_q4_0_oneblock_avx512(
|
||||||
|
__m512 acc,
|
||||||
|
const uint8_t * pd0,
|
||||||
|
const uint8_t * pd1,
|
||||||
|
const uint8_t * pb0,
|
||||||
|
const uint8_t * pb1,
|
||||||
|
size_t bs,
|
||||||
|
int i
|
||||||
|
) {
|
||||||
|
const float * d0_0 = (const float *) (pd0 + i*bs);
|
||||||
|
const float * d1_0 = (const float *) (pd1 + i*bs);
|
||||||
|
|
||||||
|
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
|
||||||
|
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
|
||||||
|
|
||||||
|
// Compute combined scale for the block
|
||||||
|
float scaleScalar = d0_0[0] * d1_0[0];
|
||||||
|
__m512 scale = _mm512_set1_ps( scaleScalar );
|
||||||
|
|
||||||
|
__m256i bx = bytesFromNibbles( p0 );
|
||||||
|
__m256i by = bytesFromNibbles( p1 );
|
||||||
|
|
||||||
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||||
|
const __m256i off = _mm256_set1_epi8( 8 );
|
||||||
|
bx = _mm256_sub_epi8( bx, off );
|
||||||
|
by = _mm256_sub_epi8( by, off );
|
||||||
|
|
||||||
|
// Sign-extend 16 signed bytes into int16_t
|
||||||
|
__m512i x32 = _mm512_cvtepi8_epi16( bx );
|
||||||
|
__m512i y32 = _mm512_cvtepi8_epi16( by );
|
||||||
|
// Compute products of int16_t integers, add pairwise
|
||||||
|
__m512i i64 = _mm512_madd_epi16( x32, y32 );
|
||||||
|
|
||||||
|
// Convert int32_t to float
|
||||||
|
__m512 p = _mm512_cvtepi32_ps( i64 );
|
||||||
|
// Apply the scale, and accumulate
|
||||||
|
return _mm512_fmadd_ps( scale, p, acc );
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// method 5
|
// method 5
|
||||||
// blocks of QK elements
|
// blocks of QK elements
|
||||||
|
@ -1419,44 +1459,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||||
#endif
|
#endif
|
||||||
#elif defined(__AVX512F__)
|
#elif defined(__AVX512F__)
|
||||||
|
|
||||||
inline __m512 process_one_block(
|
|
||||||
__m512 acc,
|
|
||||||
const uint8_t * pd0,
|
|
||||||
const uint8_t * pd1,
|
|
||||||
const uint8_t * pb0,
|
|
||||||
const uint8_t * pb1,
|
|
||||||
int i
|
|
||||||
) {
|
|
||||||
const float * d0_0 = (const float *) (pd0 + i*bs);
|
|
||||||
const float * d1_0 = (const float *) (pd1 + i*bs);
|
|
||||||
|
|
||||||
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
|
|
||||||
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
|
|
||||||
|
|
||||||
// Compute combined scale for the block
|
|
||||||
float scaleScalar = d0_0[0] * d1_0[0];
|
|
||||||
__m512 scale = _mm512_set1_ps( scaleScalar );
|
|
||||||
|
|
||||||
__m256i bx = bytesFromNibbles( p0 );
|
|
||||||
__m256i by = bytesFromNibbles( p1 );
|
|
||||||
|
|
||||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
|
||||||
const __m256i off = _mm256_set1_epi8( 8 );
|
|
||||||
bx = _mm256_sub_epi8( bx, off );
|
|
||||||
by = _mm256_sub_epi8( by, off );
|
|
||||||
|
|
||||||
// Sign-extend 16 signed bytes into int16_t
|
|
||||||
__m512i x32 = _mm512_cvtepi8_epi16( bx );
|
|
||||||
__m512i y32 = _mm512_cvtepi8_epi16( by );
|
|
||||||
// Compute products of int16_t integers, add pairwise
|
|
||||||
__m512i i64 = _mm512_madd_epi16( x32, y32 );
|
|
||||||
|
|
||||||
// Convert int32_t to float
|
|
||||||
__m512 p = _mm512_cvtepi32_ps( i64 );
|
|
||||||
// Apply the scale, and accumulate
|
|
||||||
return _mm512_fmadd_ps( scale, p, acc );
|
|
||||||
}
|
|
||||||
|
|
||||||
#if QK == 32
|
#if QK == 32
|
||||||
// Initialize accumulator with zeros
|
// Initialize accumulator with zeros
|
||||||
__m512 acc0 = _mm512_setzero_ps();
|
__m512 acc0 = _mm512_setzero_ps();
|
||||||
|
@ -1469,19 +1471,19 @@ inline __m512 process_one_block(
|
||||||
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
||||||
int i = superblock_ix * superblock_size;
|
int i = superblock_ix * superblock_size;
|
||||||
|
|
||||||
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+0 );
|
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
|
||||||
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+1 );
|
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
|
||||||
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+2 );
|
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
|
||||||
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+3 );
|
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
|
||||||
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+4 );
|
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
|
||||||
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+5 );
|
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
|
||||||
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+6 );
|
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
|
||||||
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+7 );
|
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remainders
|
// Remainders
|
||||||
for (int i = superblock_count * superblock_size; i < nb; ++i) {
|
for (int i = superblock_count * superblock_size; i < nb; ++i) {
|
||||||
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i );
|
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
|
||||||
}
|
}
|
||||||
|
|
||||||
// Horizontal sum of all lanes of the accumulator
|
// Horizontal sum of all lanes of the accumulator
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue