From ad051ac56ecf15b86da7d228d90b802da98b77be Mon Sep 17 00:00:00 2001 From: Matvey Soloviev Date: Wed, 15 Mar 2023 01:29:36 +0100 Subject: [PATCH] Small optimisations to q4_1 dot product (@Const-me) --- ggml.c | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/ggml.c b/ggml.c index 795ed4eff..8b70806aa 100644 --- a/ggml.c +++ b/ggml.c @@ -1613,12 +1613,19 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void const uint8_t * restrict p0 = pb0 + i*QK/2; const uint8_t * restrict p1 = pb1 + i*QK/2; + const __m256 d0v = _mm256_broadcast_ss( d0 ); + const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 m0v = _mm256_broadcast_ss( m0 ); + const __m256 m1v = _mm256_broadcast_ss( m1 ); + + // Compute combined scale for the block - const __m256 scale_01 = _mm256_mul_ps( _mm256_broadcast_ss( d0 ), _mm256_broadcast_ss( d1 ) ); + const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); // Compute cross scales for the block - const __m256 scale_0 = _mm256_mul_ps( _mm256_broadcast_ss( d0 ), _mm256_broadcast_ss( m1 ) ); - const __m256 scale_1 = _mm256_mul_ps( _mm256_broadcast_ss( m0 ), _mm256_broadcast_ss( d1 ) ); + const __m256 scale_0 = _mm256_mul_ps( d0v, m1v ); + const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); + const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes __m256i bx = bytesFromNibbles( p0 ); @@ -1639,20 +1646,22 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); // compute sums of unsigned bytes in bx, by in blocks of 8. - // This results in a layout like S100 0000 S200 0000 S300 0000 S400 0000, - // so if we then cast to 8 singles, we get 8 floats like [ s0_7, 0.0, s8_15, 0.0, s16_23, 0.0, s24_31, 0.0 ] - __m256 xsum = _mm256_cvtepi32_ps( _mm256_sad_epu8( bx, _mm256_setzero_si256() ) ); - __m256 ysum = _mm256_cvtepi32_ps( _mm256_sad_epu8( by, _mm256_setzero_si256() ) ); + // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, + // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. + // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] + __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); + __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); + __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); + __m256 sums = _mm256_cvtepi32_ps( sumsi ); // Convert int32_t to float __m256 p = _mm256_cvtepi32_ps( i32 ); // Apply the scale, and accumulate // acc += d0*d1*x*y + d0*m1*x + d1*m0*y acc = _mm256_fmadd_ps( scale_01, p, acc ); - acc = _mm256_fmadd_ps( scale_0, xsum, acc ); - acc = _mm256_fmadd_ps( scale_1, ysum, acc ); + acc = _mm256_fmadd_ps( cross_scales, sums, acc ); // acc_offset += m0*m1 (for each entry in the block) - acc_offset += (*m0)*(*m1)*QK; + acc_offset += (*m0)*(*m1); } // Return horizontal sum of the acc vector @@ -1661,7 +1670,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - sumf = _mm_cvtss_f32( res ) + acc_offset; + sumf = _mm_cvtss_f32( res ) + acc_offset * QK; #else #error "not implemented for QK" #endif