From db8d0f17981ce21820817da8f546073fe7a8c367 Mon Sep 17 00:00:00 2001 From: Casey Primozic Date: Tue, 21 Mar 2023 01:38:38 -0700 Subject: [PATCH] Move AVX512 dot product block helper closer to caller --- ggml.c | 82 +++++++++++++++++++++++++++++----------------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/ggml.c b/ggml.c index 62a87485c..f85138f38 100644 --- a/ggml.c +++ b/ggml.c @@ -397,47 +397,6 @@ static inline __m128i packNibbles( __m256i bytes ) } #endif -#if __AVX512F__ && QK == 32 -static inline __m512 dot_q4_0_oneblock_avx512( - __m512 acc, - const uint8_t * pd0, - const uint8_t * pd1, - const uint8_t * pb0, - const uint8_t * pb1, - size_t bs, - int i -) { - const float * d0_0 = (const float *) (pd0 + i*bs); - const float * d1_0 = (const float *) (pd1 + i*bs); - - const uint8_t * restrict p0 = pb0 + (i+0)*bs; - const uint8_t * restrict p1 = pb1 + (i+0)*bs; - - // Compute combined scale for the block - float scaleScalar = d0_0[0] * d1_0[0]; - __m512 scale = _mm512_set1_ps( scaleScalar ); - - __m256i bx = bytesFromNibbles( p0 ); - __m256i by = bytesFromNibbles( p1 ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - by = _mm256_sub_epi8( by, off ); - - // Sign-extend 16 signed bytes into int16_t - __m512i x32 = _mm512_cvtepi8_epi16( bx ); - __m512i y32 = _mm512_cvtepi8_epi16( by ); - // Compute products of int16_t integers, add pairwise - __m512i i64 = _mm512_madd_epi16( x32, y32 ); - - // Convert int32_t to float - __m512 p = _mm512_cvtepi32_ps( i64 ); - // Apply the scale, and accumulate - return _mm512_fmadd_ps( scale, p, acc ); -} -#endif - // method 5 // blocks of QK elements // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) @@ -1302,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float *s = sumf; } +#if __AVX512F__ && QK == 32 +static inline __m512 dot_q4_0_oneblock_avx512( + __m512 acc, + const uint8_t * pd0, + const uint8_t * pd1, + const uint8_t * pb0, + const uint8_t * pb1, + size_t bs, + int i +) { + const float * d0_0 = (const float *) (pd0 + i*bs); + const float * d1_0 = (const float *) (pd1 + i*bs); + + const uint8_t * restrict p0 = pb0 + (i+0)*bs; + const uint8_t * restrict p1 = pb1 + (i+0)*bs; + + // Compute combined scale for the block + float scaleScalar = d0_0[0] * d1_0[0]; + __m512 scale = _mm512_set1_ps( scaleScalar ); + + __m256i bx = bytesFromNibbles( p0 ); + __m256i by = bytesFromNibbles( p1 ); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + by = _mm256_sub_epi8( by, off ); + + // Sign-extend 16 signed bytes into int16_t + __m512i x32 = _mm512_cvtepi8_epi16( bx ); + __m512i y32 = _mm512_cvtepi8_epi16( by ); + // Compute products of int16_t integers, add pairwise + __m512i i64 = _mm512_madd_epi16( x32, y32 ); + + // Convert int32_t to float + __m512 p = _mm512_cvtepi32_ps( i64 ); + // Apply the scale, and accumulate + return _mm512_fmadd_ps( scale, p, acc ); +} +#endif + inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0;