ggml : fix AVX2 implementation

This commit is contained in:
Georgi Gerganov 2023-05-11 21:22:27 +03:00
parent bd5e373058
commit 5bc286ab18
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

106
ggml.c
View file

@ -473,23 +473,16 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
//
#if __AVX__ || __AVX2__ || __AVX512F__
// Unpack 16 4-bit fields into 16 bytes
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
{
// Load 8 bytes from memory
__m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
// Expand bytes into uint16_t values
__m128i bytes = _mm_cvtepu8_epi16( tmp );
// Unpack values into individual bytes
const __m128i lowMask = _mm_set1_epi8( 0xF );
__m128i high = _mm_andnot_si128( lowMask, bytes );
__m128i low = _mm_and_si128( lowMask, bytes );
high = _mm_slli_epi16( high, 4 );
bytes = _mm_or_si128( low, high );
return bytes;
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(x, x);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(y, x);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
return _mm_madd_epi16(ones, dot);
}
// horizontally add 8 floats
@ -524,14 +517,21 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m256i shuf_mask = _mm256_set_epi64x(
0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000);
0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000);
__m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytes = _mm256_or_si256(bytes, bit_mask);
return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
}
static inline __m256i bytes_from_nibbles_32_deinterleave(const uint8_t * rsi) {
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
const __m256i lowMask = _mm256_set1_epi8( 0xF );
return _mm256_and_si256(lowMask, bytes);
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@ -984,7 +984,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__AVX2__) || defined(__AVX__)
#elif defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@ -1029,7 +1029,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
#if defined(__AVX2__) // || defined(__AVX__) TODO
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@ -1037,10 +1037,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
// TODO: find a smarter way to do this
i2 = _mm256_permute2f128_si256(i0, i0, 0x01);
i1 = _mm256_shuffle_epi8(i0, _mm256_setr_epi8( 0, 2,-1,-1, 4, 6,-1,-1, 8,10,-1,-1,12,14,-1,-1,-1,-1, 1, 3,-1,-1, 5, 7,-1,-1, 9,11,-1,-1,13,15));
i2 = _mm256_shuffle_epi8(i2, _mm256_setr_epi8(-1,-1, 0, 2,-1,-1, 4, 6,-1,-1, 8,10,-1,-1,12,14, 1, 3,-1,-1, 5, 7,-1,-1, 9,11,-1,-1,13,15,-1,-1));
i0 = _mm256_or_si256(i1, i2);
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
@ -1152,7 +1153,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
y[i].s = d * vaddvq_s32(accv);
}
#elif defined(__AVX2__) || defined(__AVX__)
#elif defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@ -1197,7 +1198,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
#if defined(__AVX2__) // || defined(__AVX__) TODO
// Compute the sum of the quants and set y[i].s
y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
@ -1208,10 +1209,11 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
// TODO: find a smarter way to do this
i2 = _mm256_permute2f128_si256(i0, i0, 0x01);
i1 = _mm256_shuffle_epi8(i0, _mm256_setr_epi8( 0, 2,-1,-1, 4, 6,-1,-1, 8,10,-1,-1,12,14,-1,-1,-1,-1, 1, 3,-1,-1, 5, 7,-1,-1, 9,11,-1,-1,13,15));
i2 = _mm256_shuffle_epi8(i2, _mm256_setr_epi8(-1,-1, 0, 2,-1,-1, 4, 6,-1,-1, 8,10,-1,-1,12,14, 1, 3,-1,-1, 5, 7,-1,-1, 9,11,-1,-1,13,15,-1,-1));
i0 = _mm256_or_si256(i1, i2);
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
@ -2101,7 +2103,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bx = bytes_from_nibbles_32_deinterleave(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
@ -2125,31 +2127,24 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8( 8 );
bx = _mm_sub_epi8( bx, off );
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
__m128i bx = _mm_and_si128(lowMask, tmp);
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
bx = _mm_sub_epi8(bx, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
i32[j] = _mm_madd_epi16(ones, dot);
}
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx = _mm_sub_epi8(bx, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
__m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
@ -2267,7 +2262,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bx = bytes_from_nibbles_32_deinterleave(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
@ -2471,7 +2466,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bx = bytes_from_nibbles_32_deinterleave(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
bx = _mm256_or_si256(bx, bxhi);
@ -2689,6 +2684,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
float summs = 0.0f;
// Main loop
@ -2697,7 +2693,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bx = bytes_from_nibbles_32_deinterleave(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
bx = _mm256_or_si256(bx, bxhi);