AVX implementations (#1370)

This commit is contained in:
Stephan Walter 2023-05-08 19:14:06 +00:00 committed by GitHub
parent d155f0f865
commit 948d124837
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 65 deletions

82
ggml.c
View file

@ -472,23 +472,16 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
//
#if __AVX__ || __AVX2__ || __AVX512F__
// Unpack 16 4-bit fields into 16 bytes
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
{
// Load 8 bytes from memory
__m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
// Expand bytes into uint16_t values
__m128i bytes = _mm_cvtepu8_epi16( tmp );
// Unpack values into individual bytes
const __m128i lowMask = _mm_set1_epi8( 0xF );
__m128i high = _mm_andnot_si128( lowMask, bytes );
__m128i low = _mm_and_si128( lowMask, bytes );
high = _mm_slli_epi16( high, 4 );
bytes = _mm_or_si128( low, high );
return bytes;
// multiply int8_t, add results pairwise twice
//
// Multiplies the signed 8-bit lanes of x and y lane by lane, then sums
// neighbouring products in two steps (8-bit products -> 16-bit pair sums ->
// 32-bit pair sums). The result has four 32-bit lanes, each holding the sum
// of four consecutive byte products.
//
// _mm_maddubs_epi16 treats its first operand as unsigned and its second as
// signed, so the sign of x is moved onto y before multiplying:
// |x| * (y with the sign of x applied) == x * y.
//
// NOTE(review): _mm_maddubs_epi16 saturates its 16-bit pair sums; presumably
// callers keep the operands small enough that this never triggers - confirm.
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(x, x);
// Sign the values of the y vectors
// (each byte of y is negated where x is negative and zeroed where x is zero)
const __m128i sy = _mm_sign_epi8(y, x);
// Perform multiplication and create 16-bit values
// (unsigned ax times signed sy, adjacent products added into 16-bit lanes)
const __m128i dot = _mm_maddubs_epi16(ax, sy);
// Multiplying by 1 and adding adjacent 16-bit lanes widens the sums to 32 bits
const __m128i ones = _mm_set1_epi16(1);
return _mm_madd_epi16(ones, dot);
}
// horizontally add 8 floats
@ -535,19 +528,10 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
// Expand bytes into uint16_t values
__m256i bytes = _mm256_cvtepu8_epi16( tmp );
// Unpack values into individual bytes
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
const __m256i lowMask = _mm256_set1_epi8( 0xF );
__m256i high = _mm256_andnot_si256( lowMask, bytes );
__m256i low = _mm256_and_si256( lowMask, bytes );
high = _mm256_slli_epi16( high, 4 );
bytes = _mm256_or_si256( low, high );
return bytes;
return _mm256_and_si256(lowMask, bytes);
}
// add int16_t pairwise and return as float vector
@ -2146,31 +2130,23 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8( 8 );
bx = _mm_sub_epi8( bx, off );
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
__m128i bx = _mm_and_si128(lowMask, tmp);
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
bx = _mm_sub_epi8(bx, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
i32[j] = _mm_madd_epi16(ones, dot);
}
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx = _mm_sub_epi8(bx, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
__m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
@ -2484,8 +2460,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
@ -2673,8 +2649,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;