AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring

This commit is contained in:
Stephan Walter 2023-04-21 11:20:48 +02:00
parent 8687c1f258
commit 63d8dff4fb

213
ggml.c
View file

@ -487,6 +487,24 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
return bytes;
}
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = _mm256_extractf128_ps(x, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
return _mm_cvtss_f32(res);
}
// horizontally add 8 int32_t
static inline int hsum_int_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@ -507,6 +525,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
return bytes;
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1);
const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
return _mm256_cvtepi32_ps(summed_pairs);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
return sum_i16_pairs_float(dot);
}
static inline __m128i packNibbles( __m256i bytes )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@ -1310,29 +1346,6 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
}
}
#ifdef __AVX2__
// There is no better way of doing this?
// I guess not, AVX is not very good at horizontal sums.
// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly
// faster than the solution below. As I don't have an AVX2 system handt right now to test,
// keeping the original.
// TODO: Please try and if it does make a differece, uncomment and remove the implementation below.
//static inline float horizontal_sum(__m256i a) {
// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
// __m256i sum = _mm256_add_epi32(a, b);
// __m256i hi = _mm256_unpackhi_epi64(sum, sum);
// sum = _mm256_add_epi32(sum, hi);
// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
//}
static inline float horizontal_sum(__m256i a) {
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
__m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
__m128i sum64 = _mm_add_epi32(hi64, sum128);
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
#endif
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
@ -1421,9 +1434,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
y[i].s = d * hsum_int_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@ -1461,20 +1473,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
// TODO: vectorize this
int sum = 0;
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
y[i].s = y[i].d * sum;
#endif
}
#else
// scalar
quantize_row_q8_0_reference(x, y, k);
#endif
#if defined __AVX__
// TODO: vectorize this
for (int i=0; i<nb; ++i) {
int sum = 0;
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
y[i].s = y[i].d * sum;
}
#endif
}
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@ -2411,8 +2420,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -2478,7 +2485,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
#endif
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -2496,32 +2503,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps( xy_q );
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc );
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -2560,15 +2548,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
*s = hsum_float_8(acc);
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;
@ -2590,9 +2573,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
}
sumf += d0*d1*sumi;
}
#endif
*s = sumf;
#endif
}
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2604,8 +2586,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const block_q4_1 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
// TODO: add AVX / WASM SIMD / etc
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -2672,7 +2652,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
#endif
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -2683,7 +2663,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
//const float * m0 = &x[i].m;
summs += x[i].m * y[i].s;
@ -2697,33 +2676,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8( bx, bx );
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8( by, bx );
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
const __m256i ones = _mm256_set1_epi16( 1 );
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
// Convert to vector of 8 int32_t to 8 floats
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
// Accumulate d0*d1*x*y
acc = _mm256_fmadd_ps( d0d1, xy, acc );
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res ) + summs;
*s = hsum_float_8(acc) + summs;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float m0 = x[i].m;
@ -2745,9 +2707,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
sumf += f0*f2 + f1*f3;
}
}
#endif
*s = sumf;
#endif
}
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2760,8 +2721,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
const block_q4_2 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -2839,7 +2798,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
#endif
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -2861,32 +2820,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps(xy_q);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps(d, q, acc);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps(acc, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
sumf = _mm_cvtss_f32(res);
*s = hsum_float_8(acc);
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs;
@ -2921,9 +2864,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
sumf += (d0 * y[i].d) * sumi_0;
sumf += (d1 * y[i].d) * sumi_1;
}
#endif
*s = sumf;
#endif
}
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2936,8 +2878,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const block_q4_3 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -3023,9 +2963,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 dx = _mm256_set_m128(d1, d0);
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
const __m256 mx = _mm256_set_m128(m1, m0);
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
const __m256i bx = _mm256_set_m128i(bx1, bx0);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
const __m256 syf = sum_i16_pairs_float(syi);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
acc = _mm256_fmadd_ps(sxy, dy, acc);
}
*s = hsum_float_8(acc);
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs;
@ -3068,9 +3040,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
}
#endif
*s = sumf;
#endif
}