ggml : add AVX quantize_row_q4_0()
This commit is contained in:
parent
9cbc404ba6
commit
79e14129e1
1 changed files with 90 additions and 0 deletions
90
ggml.c
90
ggml.c
|
@ -461,6 +461,22 @@ static inline __m128i packNibbles( __m256i bytes )
|
||||||
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
||||||
return _mm_packus_epi16( r0, r1 );
|
return _mm_packus_epi16( r0, r1 );
|
||||||
}
|
}
|
||||||
|
#elif __AVX__
|
||||||
|
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||||
|
{
|
||||||
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||||
|
const __m128i lowByte = _mm_set1_epi16( 0xFF );
|
||||||
|
__m128i high = _mm_andnot_si128( lowByte, bytes1 );
|
||||||
|
__m128i low = _mm_and_si128( lowByte, bytes1 );
|
||||||
|
high = _mm_srli_epi16( high, 4 );
|
||||||
|
bytes1 = _mm_or_si128( low, high );
|
||||||
|
high = _mm_andnot_si128( lowByte, bytes2 );
|
||||||
|
low = _mm_and_si128( lowByte, bytes2 );
|
||||||
|
high = _mm_srli_epi16( high, 4 );
|
||||||
|
bytes2 = _mm_or_si128( low, high );
|
||||||
|
|
||||||
|
return _mm_packus_epi16( bytes1, bytes2);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// method 5
|
// method 5
|
||||||
|
@ -660,6 +676,80 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
||||||
__m128i res = packNibbles( i0 );
|
__m128i res = packNibbles( i0 );
|
||||||
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
||||||
}
|
}
|
||||||
|
#elif defined(__AVX__)
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
// Load elements into 4 AVX vectors
|
||||||
|
__m256 v0 = _mm256_loadu_ps( x );
|
||||||
|
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
||||||
|
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
||||||
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||||
|
x += 32;
|
||||||
|
|
||||||
|
// Compute max(abs(e)) for the block
|
||||||
|
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
||||||
|
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
||||||
|
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
||||||
|
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
||||||
|
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
||||||
|
|
||||||
|
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
||||||
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||||
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||||
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||||
|
|
||||||
|
// Quantize these floats
|
||||||
|
const float d = maxScalar / 7.0f;
|
||||||
|
y[i].d = d;
|
||||||
|
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
|
||||||
|
const __m256 mul = _mm256_set1_ps( id );
|
||||||
|
|
||||||
|
// Apply the multiplier
|
||||||
|
v0 = _mm256_mul_ps( v0, mul );
|
||||||
|
v1 = _mm256_mul_ps( v1, mul );
|
||||||
|
v2 = _mm256_mul_ps( v2, mul );
|
||||||
|
v3 = _mm256_mul_ps( v3, mul );
|
||||||
|
|
||||||
|
// Round to nearest integer
|
||||||
|
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||||
|
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||||
|
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||||
|
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||||
|
|
||||||
|
// Convert floats to integers
|
||||||
|
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||||
|
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||||
|
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||||
|
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||||
|
|
||||||
|
// Since we don't have in AVX some necessary functions,
|
||||||
|
// we split the registers in half and call AVX2 analogs from SSE
|
||||||
|
__m128i ni0 = _mm256_castsi256_si128( i0 );
|
||||||
|
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
|
||||||
|
__m128i ni2 = _mm256_castsi256_si128( i1 );
|
||||||
|
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
|
||||||
|
__m128i ni4 = _mm256_castsi256_si128( i2 );
|
||||||
|
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
|
||||||
|
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
||||||
|
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
||||||
|
|
||||||
|
// Convert int32 to int16
|
||||||
|
ni0 = _mm_packs_epi32( ni0, ni1 );
|
||||||
|
ni2 = _mm_packs_epi32( ni2, ni3 );
|
||||||
|
ni4 = _mm_packs_epi32( ni4, ni5 );
|
||||||
|
ni6 = _mm_packs_epi32( ni6, ni7 );
|
||||||
|
// Convert int16 to int8
|
||||||
|
ni0 = _mm_packs_epi16( ni0, ni2 );
|
||||||
|
ni4 = _mm_packs_epi16( ni4, ni6 );
|
||||||
|
|
||||||
|
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
|
||||||
|
const __m128i off = _mm_set1_epi8( 8);
|
||||||
|
ni0 = _mm_add_epi8( ni0, off );
|
||||||
|
ni4 = _mm_add_epi8( ni4, off );
|
||||||
|
|
||||||
|
// Compress the vector into 4 bit/value, and store
|
||||||
|
__m128i res = packNibbles( ni0, ni4 );
|
||||||
|
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
||||||
|
}
|
||||||
#elif defined(__wasm_simd128__)
|
#elif defined(__wasm_simd128__)
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
float amax = 0.0f; // absolute max
|
float amax = 0.0f; // absolute max
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue