AVX implementations (#1370)
This commit is contained in:
parent
d155f0f865
commit
948d124837
2 changed files with 33 additions and 65 deletions
82
ggml.c
82
ggml.c
|
@ -472,23 +472,16 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|||
//
|
||||
|
||||
#if __AVX__ || __AVX2__ || __AVX512F__
|
||||
// Unpack 16 4-bit fields into 16 bytes
|
||||
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
||||
{
|
||||
// Load 8 bytes from memory
|
||||
__m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
|
||||
|
||||
// Expand bytes into uint16_t values
|
||||
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
||||
|
||||
// Unpack values into individual bytes
|
||||
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
||||
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
||||
__m128i low = _mm_and_si128( lowMask, bytes );
|
||||
high = _mm_slli_epi16( high, 4 );
|
||||
bytes = _mm_or_si128( low, high );
|
||||
return bytes;
|
||||
// multiply int8_t, add results pairwise twice
|
||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||
// Get absolute values of x vectors
|
||||
const __m128i ax = _mm_sign_epi8(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m128i sy = _mm_sign_epi8(y, x);
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
||||
const __m128i ones = _mm_set1_epi16(1);
|
||||
return _mm_madd_epi16(ones, dot);
|
||||
}
|
||||
|
||||
// horizontally add 8 floats
|
||||
|
@ -535,19 +528,10 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
|||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
||||
{
|
||||
// Load 16 bytes from memory
|
||||
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
|
||||
|
||||
// Expand bytes into uint16_t values
|
||||
__m256i bytes = _mm256_cvtepu8_epi16( tmp );
|
||||
|
||||
// Unpack values into individual bytes
|
||||
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
|
||||
const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
|
||||
const __m256i lowMask = _mm256_set1_epi8( 0xF );
|
||||
__m256i high = _mm256_andnot_si256( lowMask, bytes );
|
||||
__m256i low = _mm256_and_si256( lowMask, bytes );
|
||||
high = _mm256_slli_epi16( high, 4 );
|
||||
bytes = _mm256_or_si256( low, high );
|
||||
return bytes;
|
||||
return _mm256_and_si256(lowMask, bytes);
|
||||
}
|
||||
|
||||
// add int16_t pairwise and return as float vector
|
||||
|
@ -2146,31 +2130,23 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||
// Compute combined scale for the block
|
||||
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
||||
|
||||
__m128i i32[2];
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
||||
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
|
||||
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
||||
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||
const __m128i off = _mm_set1_epi8(8);
|
||||
|
||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||
const __m128i off = _mm_set1_epi8( 8 );
|
||||
bx = _mm_sub_epi8( bx, off );
|
||||
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
|
||||
// Get absolute values of x vectors
|
||||
const __m128i ax = _mm_sign_epi8(bx, bx);
|
||||
__m128i bx = _mm_and_si128(lowMask, tmp);
|
||||
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
|
||||
|
||||
// Sign the values of the y vectors
|
||||
const __m128i sy = _mm_sign_epi8(by, bx);
|
||||
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
||||
|
||||
const __m128i ones = _mm_set1_epi16(1);
|
||||
i32[j] = _mm_madd_epi16(ones, dot);
|
||||
}
|
||||
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
||||
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
|
||||
|
||||
// Convert int32_t to float
|
||||
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
||||
__m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
|
||||
// Apply the scale, and accumulate
|
||||
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
||||
}
|
||||
|
@ -2484,8 +2460,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
|
|||
int sumi = 0;
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
|
||||
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
|
||||
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
|
||||
|
||||
const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
|
||||
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
|
||||
|
@ -2673,8 +2649,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
int sumi = 0;
|
||||
|
||||
for (int j = 0; j < qk/2; ++j) {
|
||||
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
|
||||
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
|
||||
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
|
||||
|
||||
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
|
||||
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue