From c29ab90e06c0ed4a2c95d691d8a619e8a981b0d6 Mon Sep 17 00:00:00 2001
From: Stephan Walter
Date: Sun, 16 Apr 2023 09:55:39 +0200
Subject: [PATCH] Q2 AVX2: do two blocks at a time, by @slaren

---
 ggml.c | 76 ++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 50 insertions(+), 26 deletions(-)

diff --git a/ggml.c b/ggml.c
index 82876e50e..bfafaaa33 100644
--- a/ggml.c
+++ b/ggml.c
@@ -488,6 +488,34 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
 }
 
 #if __AVX2__ || __AVX512F__
+// Unpack 32 2-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
+static inline __m256i bytes_from_crumbs(uint32_t packed_hi, uint32_t packed_lo) {
+    __m128i bx_hi = _mm_set1_epi32(packed_hi);
+    __m128i bx_lo = _mm_set1_epi32(packed_lo);
+    __m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
+
+    // shift counts to get all bit pairs in lowest position of each byte
+    const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
+                                              6, 4, 2, 0);
+    bx = _mm256_srlv_epi32(bx, shift256);
+
+    const __m256i shufmask = _mm256_set_epi8(15,11, 7, 3,
+                                             14,10, 6, 2,
+                                             13, 9, 5, 1,
+                                             12, 8, 4, 0,
+                                             15,11, 7, 3,
+                                             14,10, 6, 2,
+                                             13, 9, 5, 1,
+                                             12, 8, 4, 0);
+    bx = _mm256_shuffle_epi8(bx, shufmask);
+
+    const __m256i mask = _mm256_set1_epi8(3);
+    bx = _mm256_and_si256(mask, bx);
+
+    return bx;
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@@ -2500,6 +2528,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK2_0 == 0);
     const int nb = n / QK2_0;
+    assert(nb % 2 == 0);
 
     const block_q2_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
@@ -2508,49 +2537,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
 #if defined(__AVX2__)
     // Initialize accumulator with zeros
-    __m128 acc = _mm_setzero_ps();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; i += 2) {
+        __m256i bx = bytes_from_crumbs(x[i+1].qs, x[i].qs);
 
-    for (int i = 0; i < nb; i++) {
         // Compute combined scale for the block
-        const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
+        const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
+        const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
+        const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
 
-        __m128i bx = _mm_set1_epi32(x[i].qs);
+        const __m256i off = _mm256_set1_epi8(2);
+        bx = _mm256_sub_epi8(bx, off);
 
-        // shift counts to get all bit pairs in lowest position of each byte
-        const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
-        bx = _mm_srlv_epi32(bx, shift128);
-
-        const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
-        bx = _mm_shuffle_epi8(bx, shufmask);
-
-        const __m128i mask = _mm_set1_epi8(3);
-        bx = _mm_and_si128(mask, bx);
-
-        const __m128i off = _mm_set1_epi8(2);
-        bx = _mm_sub_epi8(bx, off);
-
-        const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
+        // Load y vector
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
 
         // Get absolute values of x vectors
-        const __m128i ax = _mm_sign_epi8(bx, bx);
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
         // Sign the values of the y vectors
-        const __m128i sy = _mm_sign_epi8(by, bx);
+        const __m256i sy = _mm256_sign_epi8(by, bx);
 
         // Perform multiplication and create 16-bit values
-        const __m128i dot = _mm_maddubs_epi16(ax, sy);
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
 
         // Convert int16_t to int32_t by adding pairwise
-        const __m128i ones = _mm_set1_epi16(1);
-        __m128i i32 = _mm_madd_epi16(dot, ones);
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i i32 = _mm256_madd_epi16(ones, dot);
 
         // Convert int32_t to float
-        const __m128 p = _mm_cvtepi32_ps(i32);
+        __m256 p = _mm256_cvtepi32_ps(i32);
 
         // Apply the scale, and accumulate
-        acc = _mm_fmadd_ps(scale, p, acc);
+        acc = _mm256_fmadd_ps(scale, p, acc);
     }
 
     // Return horizontal sum of the acc vector
-    __m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
     res = _mm_add_ss(res, _mm_movehdup_ps(res));
     sumf = _mm_cvtss_f32(res);
 #else
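
Note, not part of the patch: a minimal scalar model of what the new bytes_from_crumbs() helper computes. Byte j of the 32-byte result holds the j-th 2-bit field ("crumb") of the corresponding packed word, counted from the least significant bit; the low 128-bit lane is filled from packed_lo and the high lane from packed_hi. The function name bytes_from_crumbs_ref and the out[] buffer are illustrative only, not code from ggml.c.

#include <stdint.h>

// Scalar equivalent of the AVX2 broadcast/shift/shuffle/mask sequence in bytes_from_crumbs().
static void bytes_from_crumbs_ref(uint32_t packed_hi, uint32_t packed_lo, uint8_t out[32]) {
    for (int j = 0; j < 16; j++) {
        out[j]      = (packed_lo >> (2*j)) & 3;  // low lane: crumbs of x[i].qs
        out[16 + j] = (packed_hi >> (2*j)) & 3;  // high lane: crumbs of x[i+1].qs
    }
}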
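
Note, not part of the patch: _mm256_maddubs_epi16 multiplies unsigned bytes from its first operand by signed bytes from its second, which is why the loop feeds it ax = |bx| and sy = by with the sign of bx applied; the per-element product still equals bx * by, and both factors are zero when bx is zero. A scalar sketch of one 32-byte block follows, using the hypothetical name dot_q2_q8_ref.

// Scalar sketch of the abs/sign trick used around _mm256_maddubs_epi16.
static int dot_q2_q8_ref(const int8_t bx[32], const int8_t by[32]) {
    int sum = 0;
    for (int j = 0; j < 32; j++) {
        int ax = bx[j] < 0 ? -bx[j] : bx[j];                     // ax = _mm256_sign_epi8(bx, bx)
        int sy = bx[j] == 0 ? 0 : (bx[j] < 0 ? -by[j] : by[j]);  // sy = _mm256_sign_epi8(by, bx)
        sum += ax * sy;                                          // equals bx[j] * by[j]
    }
    return sum;
}

Since the q2 values lie in [-2, 1] once the offset of 2 is subtracted and the q8 values are int8, each pair of products summed by _mm256_maddubs_epi16 is at most 2 * 128 * 2 = 512 in magnitude, so the saturating 16-bit intermediate cannot overflow here.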