ggml: add AVX support

katsu560, 2023-05-05 23:39:07 +09:00
parent 799fdc1b5d
commit 60196ae73d

ggml.c (376 changed lines)
@@ -588,6 +588,93 @@ static inline __m128i packNibbles( __m256i bytes )
#endif
}
#else
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytesl = _mm_or_si128(bytesl, bit_mask);
bytesh = _mm_or_si128(bytesh, bit_mask);
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
return _mm256_set_m128i(bytesh, bytesl);
}
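
The trick above: the shuffle broadcasts each source byte into eight consecutive lanes, the OR sets every bit of a lane except the one that lane is testing, and the compare against all-ones turns that remaining bit into a full-byte mask. A scalar sketch of the resulting contract (illustrative only, not part of the commit; headers shown here once, later sketches assume the same):

#include <stdint.h>
#include <string.h>

// Scalar reference: output byte k is 0xFF iff bit k of the 32-bit input is set.
static void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int k = 0; k < 32; ++k) {
        out[k] = (x32 >> k) & 1 ? 0xFF : 0x00;
    }
}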
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
// Expand bytes into uint16_t values
__m128i bytesl = _mm_cvtepu8_epi16( tmp );
__m128i bytesh = _mm_unpackhi_epi8( tmp, _mm_setzero_si128() );
// Unpack values into individual bytes
const __m128i lowMask = _mm_set1_epi8( 0xF );
__m128i highl = _mm_andnot_si128( lowMask, bytesl );
__m128i highh = _mm_andnot_si128( lowMask, bytesh );
const __m128i lowl = _mm_and_si128( lowMask, bytesl );
const __m128i lowh = _mm_and_si128( lowMask, bytesh );
highl = _mm_slli_epi16( highl, 4 );
highh = _mm_slli_epi16( highh, 4 );
bytesl = _mm_or_si128( lowl, highl );
bytesh = _mm_or_si128( lowh, highh );
__m256i bytes = _mm256_set_m128i(bytesh, bytesl);
return bytes;
}
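
Net effect: input byte k contributes its low nibble to output byte 2k and its high nibble to output byte 2k+1. A scalar sketch (hypothetical helper, not in the commit):

static void bytes_from_nibbles_32_ref(const uint8_t * rsi, uint8_t out[32]) {
    for (int k = 0; k < 16; ++k) {
        out[2*k + 0] = rsi[k] & 0x0F; // low nibble
        out[2*k + 1] = rsi[k] >>   4; // high nibble
    }
}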
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
const __m128i ones = _mm_set1_epi16(1);
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
return _mm256_cvtepi32_ps(summed_pairs);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m128i xl = _mm256_castsi256_si128(x);
const __m128i xh = _mm256_extractf128_si256(x, 1);
const __m128i yl = _mm256_castsi256_si128(y);
const __m128i yh = _mm256_extractf128_si256(y, 1);
// Get absolute values of x vectors
const __m128i axl = _mm_sign_epi8(xl, xl);
const __m128i axh = _mm_sign_epi8(xh, xh);
// Sign the values of the y vectors
const __m128i syl = _mm_sign_epi8(yl, xl);
const __m128i syh = _mm_sign_epi8(yh, xh);
// Perform multiplication and create 16-bit values
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
const __m128i doth = _mm_maddubs_epi16(axh, syh);
return sum_i16_pairs_float(doth, dotl);
}
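
_mm_maddubs_epi16 multiplies unsigned bytes by signed bytes, so the signed-times-signed product x*y is rewritten as |x| * sign(y, x): the absolute value goes into the unsigned operand and x's sign is transferred onto y (when x is 0 the product is 0 either way). The 256-bit inputs are split into 128-bit halves because AVX alone has no 256-bit integer ops. Each float lane ends up holding the sum of four adjacent int8 products; a scalar sketch (hypothetical, not in the commit):

static void mul_sum_i8_pairs_float_ref(const int8_t x[32], const int8_t y[32], float out[8]) {
    for (int j = 0; j < 8; ++j) {
        int32_t sum = 0;
        for (int k = 0; k < 4; ++k) { // four int8 products per 32-bit lane
            sum += (int32_t)x[4*j + k] * (int32_t)y[4*j + k];
        }
        out[j] = (float)sum;
    }
}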
static inline __m128i packNibbles_256( __m256i bytes )
{
__m128i bytesl = _mm256_castsi256_si128(bytes);
__m128i bytesh = _mm256_extractf128_si256(bytes, 1);
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m128i lowByte = _mm_set1_epi16( 0xFF );
__m128i highl = _mm_andnot_si128( lowByte, bytesl );
__m128i highh = _mm_andnot_si128( lowByte, bytesh );
const __m128i lowl = _mm_and_si128( lowByte, bytesl );
const __m128i lowh = _mm_and_si128( lowByte, bytesh );
highl = _mm_srli_epi16( highl, 4 );
highh = _mm_srli_epi16( highh, 4 );
bytesl = _mm_or_si128( lowl, highl );
bytesh = _mm_or_si128( lowh, highh );
// Compress uint16_t lanes into bytes
return _mm_packus_epi16( bytesl, bytesh );
}
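
This is the inverse of bytes_from_nibbles_32: each 16-bit lane carries one nibble per byte, and the two are fused into a single output byte. Scalar sketch (hypothetical):

static void packNibbles_256_ref(const uint8_t in[32], uint8_t out[16]) {
    for (int k = 0; k < 16; ++k) {
        out[k] = (uint8_t)((in[2*k] & 0x0F) | ((in[2*k + 1] & 0x0F) << 4));
    }
}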
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -1096,23 +1183,23 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
__m128i ni0 = _mm256_castsi256_si128( i0 ); // 0,1,2,3
__m128i ni1 = _mm256_extractf128_si256( i0, 1); // 4,5,6,7
__m128i ni2 = _mm256_castsi256_si128( i1 ); // 8,9,10,11
__m128i ni3 = _mm256_extractf128_si256( i1, 1); // 12,13,14,15
__m128i ni4 = _mm256_castsi256_si128( i2 ); // 16,17,18,19
__m128i ni5 = _mm256_extractf128_si256( i2, 1); // 20,21,22,23
__m128i ni6 = _mm256_castsi256_si128( i3 ); // 24,25,26,27
__m128i ni7 = _mm256_extractf128_si256( i3, 1); // 28,29,30,31
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 ); // 0,1,2,3, 4,5,6,7
ni2 = _mm_packs_epi32( ni2, ni3 ); // 8,9,10,11, 12,13,14,15
ni4 = _mm_packs_epi32( ni4, ni5 ); // 16,17,18,19, 20,21,22,23
ni6 = _mm_packs_epi32( ni6, ni7 ); // 24,25,26,27, 28,29,30,31
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 ); // 0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15
ni4 = _mm_packs_epi16( ni4, ni6 ); // 16,17,18,19, 20,21,22,23, 24,25,26,27, 28,29,30,31
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m128i off = _mm_set1_epi8( 8 );
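
Unlike the 256-bit AVX2 packs (which interleave the two 128-bit lanes and need a fix-up permute afterwards), the 128-bit packs keep elements in order, so the chain above is simply an in-order saturating narrowing. Scalar sketch (hypothetical):

static void pack_i32_to_i8_ref(const int32_t in[32], int8_t out[32]) {
    for (int k = 0; k < 32; ++k) {
        const int32_t v = in[k];
        out[k] = (int8_t)(v < -128 ? -128 : v > 127 ? 127 : v); // saturate
    }
}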
@@ -1222,7 +1309,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
block_q4_1 * restrict y = vy;
#if defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@@ -1280,10 +1367,11 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
@@ -1294,6 +1382,30 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
// Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 );
#elif defined(__AVX__)
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
__m128i ni0 = _mm256_castsi256_si128( i0 ); // 0,1,2,3
__m128i ni1 = _mm256_extractf128_si256( i0, 1); // 4,5,6,7
__m128i ni2 = _mm256_castsi256_si128( i1 ); // 8,9,10,11
__m128i ni3 = _mm256_extractf128_si256( i1, 1); // 12,13,14,15
__m128i ni4 = _mm256_castsi256_si128( i2 ); // 16,17,18,19
__m128i ni5 = _mm256_extractf128_si256( i2, 1); // 20,21,22,23
__m128i ni6 = _mm256_castsi256_si128( i3 ); // 24,25,26,27
__m128i ni7 = _mm256_extractf128_si256( i3, 1); // 28,29,30,31
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 ); // 0,1,2,3, 4,5,6,7
ni2 = _mm_packs_epi32( ni2, ni3 ); // 8,9,10,11, 12,13,14,15
ni4 = _mm_packs_epi32( ni4, ni5 ); // 16,17,18,19, 20,21,22,23
ni6 = _mm_packs_epi32( ni6, ni7 ); // 24,25,26,27, 28,29,30,31
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 ); // 0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15
ni4 = _mm_packs_epi16( ni4, ni6 ); // 16,17,18,19, 20,21,22,23, 24,25,26,27, 28,29,30,31
// Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( ni0, ni4 );
#endif
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
}
#elif __ARM_NEON
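
For reference, the arithmetic the whole q4_1 loop implements per block, assuming the usual ggml q4_1 scheme of this era (per-block minimum m and scale d = (max - min)/15). A hedged scalar sketch, not code from the commit:

static void quantize_block_q4_1_ref(const float * x, int n,
                                    float * d_out, float * m_out, uint8_t * qs) {
    float vmin = x[0], vmax = x[0];
    for (int k = 1; k < n; ++k) {
        if (x[k] < vmin) vmin = x[k];
        if (x[k] > vmax) vmax = x[k];
    }
    const float d  = (vmax - vmin) / 15.0f;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;
    for (int k = 0; k < n; k += 2) {
        int q0 = (int)((x[k + 0] - vmin)*id + 0.5f); if (q0 > 15) q0 = 15;
        int q1 = (int)((x[k + 1] - vmin)*id + 0.5f); if (q1 > 15) q1 = 15;
        qs[k/2] = (uint8_t)(q0 | (q1 << 4)); // two values per byte
    }
    *d_out = d;
    *m_out = vmin;
}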
@@ -1874,6 +1986,48 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
}
}
}
#elif defined(__AVX__)
for (int i = 0; i < nb; i++) {
// scale factor
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK4_0; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
__m128i vx8_lo = _mm256_castsi256_si128(vx8);
__m128i vx8_hi = _mm256_extractf128_si256(vx8, 1);
// Subtract 8 from the integers
vx8_lo = _mm_sub_epi8(vx8_lo, _mm_set1_epi8(8));
vx8_hi = _mm_sub_epi8(vx8_hi, _mm_set1_epi8(8));
// Convert to 16-bit int
__m128i vx16_lol = _mm_cvtepi8_epi16(vx8_lo);
__m128i vx16_loh = _mm_cvtepi8_epi16(_mm_srli_si128(vx8_lo, 8));
__m128i vx16_hil = _mm_cvtepi8_epi16(vx8_hi);
__m128i vx16_hih = _mm_cvtepi8_epi16(_mm_srli_si128(vx8_hi, 8));
// Convert to 32-bit int -> float 32
const __m256 vf[4] = {
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_lol, 8)), _mm_cvtepi16_epi32(vx16_lol))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_loh, 8)), _mm_cvtepi16_epi32(vx16_loh))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_hil, 8)), _mm_cvtepi16_epi32(vx16_hil))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_hih, 8)), _mm_cvtepi16_epi32(vx16_hih)))
};
// Scale and store
__m256 result = _mm256_mul_ps(vf[0], d_v);
_mm256_storeu_ps(y + i * QK4_0 + l + 0, result);
result = _mm256_mul_ps(vf[1], d_v);
_mm256_storeu_ps(y + i * QK4_0 + l + 8, result);
result = _mm256_mul_ps(vf[2], d_v);
_mm256_storeu_ps(y + i * QK4_0 + l + 16, result);
result = _mm256_mul_ps(vf[3], d_v);
_mm256_storeu_ps(y + i * QK4_0 + l + 24, result);
}
}
#elif defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
const float32x4_t vd = vdupq_n_f32(x[i].d);
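
The widening chain in the AVX branch above (nibbles to int8 to int16 to int32 to float) computes, per element, y[k] = d * (nibble(k) - 8). Scalar sketch (hypothetical helper):

static void dequantize_block_q4_0_ref(const uint8_t * qs, float d, int n, float * y) {
    for (int k = 0; k < n; k += 2) {
        y[k + 0] = d * (float)((qs[k/2] & 0x0F) - 8);
        y[k + 1] = d * (float)((qs[k/2] >>   4) - 8);
    }
}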
@@ -1983,10 +2137,52 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
};
// Scale, add m and store
__m256 result = _mm256_fmadd_ps(vf[0], d_v, d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 0, result);
result = _mm256_fmadd_ps(vf[1], d_v, d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 8, result);
result = _mm256_fmadd_ps(vf[2], d_v, d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 16, result);
result = _mm256_fmadd_ps(vf[3], d_v, d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 24, result);
}
}
#elif defined(__AVX__)
for (int i = 0; i < nb; i++) {
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK4_1; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
const __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
const __m128i vx8_lo = _mm256_castsi256_si128(vx8);
const __m128i vx8_hi = _mm256_extractf128_si256(vx8, 1);
// Convert to 16-bit int
const __m128i vx16_lol = _mm_cvtepi8_epi16(vx8_lo);
const __m128i vx16_loh = _mm_cvtepi8_epi16(_mm_srli_si128(vx8_lo, 8));
const __m128i vx16_hil = _mm_cvtepi8_epi16(vx8_hi);
const __m128i vx16_hih = _mm_cvtepi8_epi16(_mm_srli_si128(vx8_hi, 8));
// Convert to 32-bit int -> float 32
const __m256 vf[4] = {
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_lol, 8)), _mm_cvtepi16_epi32(vx16_lol))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_loh, 8)), _mm_cvtepi16_epi32(vx16_loh))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_hil, 8)), _mm_cvtepi16_epi32(vx16_hil))),
_mm256_cvtepi32_ps(_mm256_set_m128i(_mm_cvtepi16_epi32(_mm_srli_si128(vx16_hih, 8)), _mm_cvtepi16_epi32(vx16_hih)))
};
// Scale, add m and store
__m256 result = _mm256_add_ps(_mm256_mul_ps(vf[0], d_v), d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 0, result);
result = _mm256_add_ps(_mm256_mul_ps(vf[1], d_v), d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 8, result);
result = _mm256_add_ps(_mm256_mul_ps(vf[2], d_v), d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 16, result);
result = _mm256_add_ps(_mm256_mul_ps(vf[3], d_v), d_m);
_mm256_storeu_ps(y + i * QK4_1 + l + 24, result);
}
}
#elif defined(__ARM_NEON)
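
Same structure as the q4_0 path, but unsigned and with the block minimum added back: y[k] = d * nibble(k) + m, using plain mul+add instead of FMA, which AVX alone does not guarantee. Scalar sketch (hypothetical):

static void dequantize_block_q4_1_ref(const uint8_t * qs, float d, float m, int n, float * y) {
    for (int k = 0; k < n; k += 2) {
        y[k + 0] = d * (float)(qs[k/2] & 0x0F) + m;
        y[k + 1] = d * (float)(qs[k/2] >>   4) + m;
    }
}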
@@ -3117,6 +3313,36 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps( d0d1, xy, acc );
}
*s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
float summs = 0;
// Main loop
for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
summs += x[i].m * (y[i].s0 + y[i].s1);
const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
// Compute combined scales
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
// Accumulate d0*d1*x*y
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
}
*s = hsum_float_8(acc) + summs;
#else
// scalar
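
The summs term works because each q4_1 value is d0*q + m, so the block dot product splits as sum((d0*q + m) * d1*qy) = d0*d1*sum(q*qy) + m*(d1*sum(qy)), and the second factor is the precomputed y[i].s0 + y[i].s1. A hedged scalar sketch, assuming s0/s1 hold d1 times the per-half sums of the int8 quants:

static float dot_q4_1_q8_1_block_ref(float d0, float m, const uint8_t q0[32],
                                     float d1, const int8_t q1[32]) {
    int isum = 0, qsum = 0;
    for (int k = 0; k < 32; ++k) {
        isum += (int)q0[k] * (int)q1[k]; // q0: unpacked nibbles, 0 .. 15
        qsum += (int)q1[k];
    }
    return d0*d1*(float)isum + m*(d1*(float)qsum); // second term == m*(s0 + s1)
}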
@@ -3261,6 +3487,35 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(d, q, acc);
}
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8(8);
bx0 = _mm_sub_epi8(bx0, off);
bx1 = _mm_sub_epi8(bx1, off);
__m256i bx = _mm256_set_m128i(bx1, bx0);
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
}
*s = hsum_float_8(acc);
#else
// scalar
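
q4_2 blocks are 16 elements wide, so one 256-bit register covers two blocks with two distinct fp16 scales, merged into a single scale vector via _mm256_set_m128. Assuming bytes_from_nibbles_16 unpacks low nibble first like its 32-wide sibling, a scalar sketch of one iteration (hypothetical):

static float dot_q4_2_q8_0_pair_ref(float d0, const uint8_t * qs0,
                                    float d1, const uint8_t * qs1,
                                    float dy, const int8_t * yq) {
    int sum0 = 0, sum1 = 0;
    for (int k = 0; k < 16; k += 2) {
        sum0 += ((qs0[k/2] & 0x0F) - 8)*yq[k]      + ((qs0[k/2] >> 4) - 8)*yq[k + 1];
        sum1 += ((qs1[k/2] & 0x0F) - 8)*yq[16 + k] + ((qs1[k/2] >> 4) - 8)*yq[16 + k + 1];
    }
    return dy*(d0*(float)sum0 + d1*(float)sum1);
}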
@@ -3463,6 +3718,36 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(d, q, acc);
}
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, _mm_set1_epi8((char)0xF0));
bxhih = _mm_andnot_si128(bxhih, _mm_set1_epi8((char)0xF0));
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
}
*s = hsum_float_8(acc);
#else
// scalar
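
The andnot against 0xF0 does two jobs at once: it injects 0xF0 into every byte whose fifth bit is clear, and a nibble OR-ed with 0xF0 is precisely the int8 sign extension of (nibble - 16), so the q5_0 offset of -16 comes for free. Scalar sketch of the per-value decode (hypothetical):

static int8_t q5_0_value_ref(uint8_t nibble, int bit5) {
    const int v = (int)nibble | (bit5 << 4); // 5-bit value, 0 .. 31
    return (int8_t)(v - 16);                 // bit5 clear -> byte pattern nibble | 0xF0
}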
@@ -3667,6 +3952,39 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
// AVX variant of the AVX2 block above
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
float summs = 0.0f;
// Main loop
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, _mm_set1_epi8(0x10));
bxhih = _mm_and_si128(bxhih, _mm_set1_epi8(0x10));
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
}
*s = hsum_float_8(acc) + summs;
#else
float sumf = 0.0;
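
Here the fifth bit is simply added back as +16 (AND with 0x10 instead of the q5_0 andnot trick); no offset is applied because the block minimum m enters through the summs term, exactly as in the q4_1 kernel. Scalar sketch of the per-value decode (hypothetical):

static int q5_1_value_ref(uint8_t nibble, int bit5) {
    return (int)nibble | (bit5 << 4); // unsigned 5-bit value, 0 .. 31
}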
@@ -3784,6 +4102,24 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps( d, q, acc );
}
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
// Multiply q with scale and accumulate
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
}
*s = hsum_float_8(acc);
#else
// scalar
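
The q8_0 path needs no unpacking at all; the only AVX-vs-AVX2 difference is plain mul+add in place of fmadd when accumulating. Per block it computes (scalar sketch, hypothetical):

static float dot_q8_0_block_ref(float dx, const int8_t * xq,
                                float dy, const int8_t * yq, int n) {
    int sum = 0;
    for (int k = 0; k < n; ++k) {
        sum += (int)xq[k] * (int)yq[k];
    }
    return dx*dy*(float)sum;
}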