diff --git a/ggml.c b/ggml.c index d1b4cbdca..b4dac6223 100644 --- a/ggml.c +++ b/ggml.c @@ -908,135 +908,7 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r } static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - -#if defined(__AVX__) - const int nb = k / qk; - - block_q5_0 * restrict yy = y; - - const __m256 signBit8 = _mm256_set1_ps( -0.0f ); - const __m128 signBit4 = _mm_set1_ps( -0.0f ); - const __m256 base = _mm256_set1_ps( 16.5f ); - const __m128i n31 = _mm_set1_epi8( 31 ); - const __m128i lowmask = _mm_set1_epi8( 0xF ); - const __m128i bit5mask = _mm_set1_epi8( 0x10 ); - - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(e) by max(abs(e)) for the block - __m256 abs0 = _mm256_andnot_ps( signBit8, v0 ); - __m256 abs1 = _mm256_andnot_ps( signBit8, v1 ); - __m256 mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); - __m256 max01 = _mm256_blendv_ps( v0, v1, mask8 ); - - abs0 = _mm256_andnot_ps( signBit8, v2 ); - abs1 = _mm256_andnot_ps( signBit8, v3 ); - mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); - __m256 max23 = _mm256_blendv_ps( v2, v3, mask8 ); - - abs0 = _mm256_andnot_ps( signBit8, max01 ); - abs1 = _mm256_andnot_ps( signBit8, max23 ); - mask8 = _mm256_cmp_ps( abs0, abs1, _CMP_LE_OQ); - max01 = _mm256_blendv_ps( max01, max23, mask8 ); - - __m128 lo = _mm256_castps256_ps128( max01 ); - __m128 hi = _mm256_extractf128_ps( max01, 1 ); - __m128 abslo = _mm_andnot_ps( signBit4, lo ); - __m128 abshi = _mm_andnot_ps( signBit4, hi ); - __m128 mask4 = _mm_cmp_ps( abslo, abshi, _CMP_LE_OQ); - __m128 maxhl = _mm_blendv_ps( lo, hi, mask4 ); - - hi = _mm_movehl_ps( maxhl, maxhl ); - abslo = _mm_andnot_ps( signBit4, maxhl ); - abshi = _mm_andnot_ps( signBit4, hi ); - mask4 = _mm_cmp_ps( abslo, abshi, _CMP_LE_OQ); - maxhl = _mm_blendv_ps( lo, hi, mask4 ); - - hi = _mm_movehdup_ps( maxhl ); - abslo = _mm_andnot_ps( signBit4, maxhl ); - abshi = _mm_andnot_ps( signBit4, hi ); - mask4 = _mm_cmp_ps( abshi, abslo, _CMP_LE_OQ); - maxhl = _mm_blendv_ps( abshi, abslo, mask4 ); - const float max = _mm_cvtss_f32( maxhl ); - - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - yy[i].d = GGML_FP32_TO_FP16(d); - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Add 16.5f - v0 = _mm256_add_ps( v0, base ); - v1 = _mm256_add_ps( v1, base ); - v2 = _mm256_add_ps( v2, base ); - v3 = _mm256_add_ps( v3, base ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1 ); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1 ); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1 ); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1 ); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - ni0 = _mm_min_epi8( n31, ni0 ); - ni4 = _mm_min_epi8( n31, ni4 ); - - // y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - ni1 = _mm_and_si128( lowmask, ni0 ); - ni5 = _mm_and_si128( lowmask, ni4 ); - ni5 = _mm_slli_epi16( ni5, 4 ); - ni1 = _mm_or_si128( ni1, ni5 ); - _mm_storeu_si128((__m128i *)(yy[i].qs + 0), ni1); - - // get the 5-th bit and store it in qh at the right position - // qh |= ((xi0 & 0x10) >> 4) << (j + 0); - // qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - ni0 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni0 ), 3 ); - ni4 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni4 ), 3 ); - uint16_t qhl = _mm_movemask_epi8( ni0 ); - uint16_t qhh = _mm_movemask_epi8( ni4 ); - memcpy(&yy[i].qh[0], &qhl, sizeof(qhl)); - memcpy(&yy[i].qh[2], &qhh, sizeof(qhh)); - } -#else quantize_row_q5_0_reference(x, y, k); -#endif } static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { @@ -1084,118 +956,7 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r } static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { - const int qk = QK5_1; - - assert(k % qk == 0); - -#if defined(__AVX__) - const int nb = k / qk; - - block_q5_1 * restrict yy = y; - - const __m256 base = _mm256_set1_ps( 0.5f ); - const __m128i lowmask = _mm_set1_epi8( 0xF ); - const __m128i bit5mask = _mm_set1_epi8( 0x10 ); - - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max,min - __m256 max8 = _mm256_max_ps( v0, v1 ); - max8 = _mm256_max_ps( max8, v2 ); - max8 = _mm256_max_ps( max8, v3 ); - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max8, 1 ), _mm256_castps256_ps128( max8 ) ); - max4 = _mm_max_ps( _mm_movehl_ps( max4, max4 ), max4 ); - max4 = _mm_max_ss( _mm_movehdup_ps( max4 ), max4 ); - const float max = _mm_cvtss_f32( max4 ); - - __m256 min8 = _mm256_min_ps( v0, v1 ); - min8 = _mm256_min_ps( min8, v2 ); - min8 = _mm256_min_ps( min8, v3 ); - __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min8, 1 ), _mm256_castps256_ps128( min8 ) ); - min4 = _mm_min_ps( _mm_movehl_ps( min4, min4 ), min4 ); - min4 = _mm_min_ss( _mm_movehdup_ps( min4 ), min4 ); - const float min = _mm_cvtss_f32( min4 ); - - const float d = (max - min) / ((1 << 5) - 1); - const float id = d ? 1.0f/d : 0.0f; - - yy[i].d = GGML_FP32_TO_FP16(d); - yy[i].m = GGML_FP32_TO_FP16(min); - - const __m256 mul = _mm256_set1_ps( id ); - - // Subtract min - min8 = _mm256_set1_ps( min ); - v0 = _mm256_sub_ps( v0, min8 ); - v1 = _mm256_sub_ps( v1, min8 ); - v2 = _mm256_sub_ps( v2, min8 ); - v3 = _mm256_sub_ps( v3, min8 ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Add 0.5f - v0 = _mm256_add_ps( v0, base ); - v1 = _mm256_add_ps( v1, base ); - v2 = _mm256_add_ps( v2, base ); - v3 = _mm256_add_ps( v3, base ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1 ); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1 ); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1 ); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1 ); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - // y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - ni1 = _mm_and_si128( lowmask, ni0 ); - ni5 = _mm_and_si128( lowmask, ni4 ); - ni5 = _mm_slli_epi16( ni5, 4 ); - ni1 = _mm_or_si128( ni1, ni5 ); - _mm_storeu_si128((__m128i *)(yy[i].qs + 0), ni1); - - // get the 5-th bit and store it in qh at the right position - // qh |= ((xi0 & 0x10) >> 4) << (j + 0); - // qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - ni0 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni0 ), 3 ); - ni4 = _mm_slli_epi16( _mm_and_si128( bit5mask, ni4 ), 3 ); - uint16_t qhl = _mm_movemask_epi8( ni0 ); - uint16_t qhh = _mm_movemask_epi8( ni4 ); - memcpy(&yy[i].qh[0], &qhl, sizeof(qhl)); - memcpy(&yy[i].qh[2], &qhh, sizeof(qhh)); - } -#else quantize_row_q5_1_reference(x, y, k); -#endif } // reference implementation for deterministic creation of model files