From 61a30466300be5ab5b4cc7ef0987de8acf7e4385 Mon Sep 17 00:00:00 2001 From: katsu560 Date: Sun, 14 May 2023 04:59:01 +0900 Subject: [PATCH] ggml : add AVX support to quantize_row_q5_0, quantize_row_q5_1 --- ggml.c | 239 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) diff --git a/ggml.c b/ggml.c index 24c79a9af..57bd37559 100644 --- a/ggml.c +++ b/ggml.c @@ -908,7 +908,135 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r } static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + +#if defined(__AVX__) + const int nb = k / qk; + + block_q5_0 * restrict yy = y; + + const __m256 signBit8 = _mm256_set1_ps( -0.0f ); + const __m128 signBit4 = _mm_set1_ps( -0.0f ); + const __m256 base = _mm256_set1_ps( 16.5f ); + const __m128i n31 = _mm_set1_epi8( 31 ); + const __m128i lowmask = _mm_set1_epi8( 0xF ); + const __m128i bit5mask = _mm_set1_epi8( 0x10 ); + + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(e) by max(abs(e)) for the block + __m256 abs0 = _mm256_andnot_ps( signBit8, v0 ); + __m256 abs1 = _mm256_andnot_ps( signBit8, v1 ); + __m256 mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); + __m256 max01 = _mm256_blendv_ps( v0, v1, mask8 ); + + abs0 = _mm256_andnot_ps( signBit8, v2 ); + abs1 = _mm256_andnot_ps( signBit8, v3 ); + mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); + __m256 max23 = _mm256_blendv_ps( v2, v3, mask8 ); + + abs0 = _mm256_andnot_ps( signBit8, max01 ); + abs1 = _mm256_andnot_ps( signBit8, max23 ); + mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); + max01 = _mm256_blendv_ps( max01, max23, mask8 ); + + __m128 lo = _mm256_castps256_ps128( max01 ); + __m128 hi = _mm256_extractf128_ps( max01, 1 ); + __m128 abslo = _mm_andnot_ps( signBit4, lo ); + __m128 abshi = _mm_andnot_ps( signBit4, hi ); + __m128 mask4 = _mm_cmp_ps( abslo, abshi, _CMP_LE_OQ); + __m128 maxhl = _mm_blendv_ps( lo, hi, mask4 ); + + hi = _mm_movehl_ps( maxhl, maxhl ); + abslo = _mm_andnot_ps( signBit4, maxhl ); + abshi = _mm_andnot_ps( signBit4, hi ); + mask4 = _mm_cmp_ps( abslo, abshi, _CMP_LE_OQ); + maxhl = _mm_blendv_ps( lo, hi, mask4 ); + + hi = _mm_movehdup_ps( maxhl ); + abslo = _mm_andnot_ps( signBit4, maxhl ); + abshi = _mm_andnot_ps( signBit4, hi ); + mask4 = _mm_cmp_ps( abshi, abslo, _CMP_LE_OQ); + maxhl = _mm_blendv_ps( abshi, abslo, mask4 ); + const float max = _mm_cvtss_f32( maxhl ); + + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + yy[i].d = GGML_FP32_TO_FP16(d); + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Add 16.5f + v0 = _mm256_add_ps( v0, base ); + v1 = _mm256_add_ps( v1, base ); + v2 = _mm256_add_ps( v2, base ); + v3 = _mm256_add_ps( v3, base ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1 ); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1 ); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1 ); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1 ); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + ni0 = _mm_min_epi8( n31, ni0 ); + ni4 = _mm_min_epi8( n31, ni4 ); + + // y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + ni1 = _mm_and_si128( lowmask, ni0 ); + ni5 = _mm_and_si128( lowmask, ni4 ); + ni5 = _mm_slli_epi16( ni5, 4 ); + ni1 = _mm_or_si128( ni1, ni5 ); + _mm_storeu_si128((__m128i *)(yy[i].qs + 0), ni1); + + // get the 5-th bit and store it in qh at the right position + // qh |= ((xi0 & 0x10) >> 4) << (j + 0); + // qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + ni0 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni0 ), 3 ); + ni4 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni4 ), 3 ); + uint16_t qhl = _mm_movemask_epi8( ni0 ); + uint16_t qhh = _mm_movemask_epi8( ni4 ); + memcpy(&yy[i].qh[0], &qhl, sizeof(qhl)); + memcpy(&yy[i].qh[2], &qhh, sizeof(qhh)); + } +#else quantize_row_q5_0_reference(x, y, k); +#endif } static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { @@ -956,7 +1084,118 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r } static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { + const int qk = QK5_1; + + assert(k % qk == 0); + +#if defined(__AVX__) + const int nb = k / qk; + + block_q5_1 * restrict yy = y; + + const __m256 base = _mm256_set1_ps( 0.5f ); + const __m128i lowmask = _mm_set1_epi8( 0xF ); + const __m128i bit5mask = _mm_set1_epi8( 0x10 ); + + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max,min + __m256 max8 = _mm256_max_ps( v0, v1 ); + max8 = _mm256_max_ps( max8, v2 ); + max8 = _mm256_max_ps( max8, v3 ); + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max8, 1 ), _mm256_castps256_ps128( max8 ) ); + max4 = _mm_max_ps( _mm_movehl_ps( max4, max4 ), max4 ); + max4 = _mm_max_ss( _mm_movehdup_ps( max4 ), max4 ); + const float max = _mm_cvtss_f32( max4 ); + + __m256 min8 = _mm256_min_ps( v0, v1 ); + min8 = _mm256_min_ps( min8, v2 ); + min8 = _mm256_min_ps( min8, v3 ); + __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min8, 1 ), _mm256_castps256_ps128( min8 ) ); + min4 = _mm_min_ps( _mm_movehl_ps( min4, min4 ), min4 ); + min4 = _mm_min_ss( _mm_movehdup_ps( min4 ), min4 ); + const float min = _mm_cvtss_f32( min4 ); + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + yy[i].d = GGML_FP32_TO_FP16(d); + yy[i].m = GGML_FP32_TO_FP16(min); + + const __m256 mul = _mm256_set1_ps( id ); + + // Subtract min + min8 = _mm256_set1_ps( min ); + v0 = _mm256_sub_ps( v0, min8 ); + v1 = _mm256_sub_ps( v1, min8 ); + v2 = _mm256_sub_ps( v2, min8 ); + v3 = _mm256_sub_ps( v3, min8 ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Add 0.5f + v0 = _mm256_add_ps( v0, base ); + v1 = _mm256_add_ps( v1, base ); + v2 = _mm256_add_ps( v2, base ); + v3 = _mm256_add_ps( v3, base ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1 ); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1 ); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1 ); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1 ); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + // y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + ni1 = _mm_and_si128( lowmask, ni0 ); + ni5 = _mm_and_si128( lowmask, ni4 ); + ni5 = _mm_slli_epi16( ni5, 4 ); + ni1 = _mm_or_si128( ni1, ni5 ); + _mm_storeu_si128((__m128i *)(yy[i].qs + 0), ni1); + + // get the 5-th bit and store it in qh at the right position + // qh |= ((xi0 & 0x10) >> 4) << (j + 0); + // qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + ni0 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni0 ), 3 ); + ni4 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni4 ), 3 ); + uint16_t qhl = _mm_movemask_epi8( ni0 ); + uint16_t qhh = _mm_movemask_epi8( ni4 ); + memcpy(&yy[i].qh[0], &qhl, sizeof(qhl)); + memcpy(&yy[i].qh[2], &qhh, sizeof(qhh)); + } +#else quantize_row_q5_1_reference(x, y, k); +#endif } // reference implementation for deterministic creation of model files