- Use 4 accumulators instead of 2

- Remove the first accumulation: the first two blocks are peeled out of the loop and their scaled products are written straight into the accumulators, so no add against a zeroed register is needed

Ideas taken from https://stackoverflow.blog/2020/07/08/improving-performance-with-simd-intrinsics-in-three-use-cases/

llama_print_timings:        load time =  3087.59 ms
llama_print_timings:      sample time =   132.04 ms /   128 runs   (    1.03 ms per token)
llama_print_timings: prompt eval time =  2894.28 ms /     8 tokens (  361.78 ms per token)
llama_print_timings:        eval time = 58529.67 ms /   127 runs   (  460.86 ms per token)
llama_print_timings:       total time = 61780.98 ms
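
For context on why the extra accumulators help: with a single accumulator every add must wait on the previous add's result, while independent accumulators give the CPU several dependency chains it can overlap, which is the technique the linked article describes. A minimal scalar sketch of the idea (illustrative only, not code from this commit):

#include <stdio.h>

// One accumulator: each add depends on the previous add's result.
static float dot_1acc(const float *a, const float *b, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += a[i] * b[i];
    }
    return acc;
}

// Four accumulators: four independent dependency chains the CPU can
// execute in parallel, combined once at the end. Assumes n is a
// multiple of 4, just as the SSE3 loop below assumes an even number
// of blocks.
static float dot_4acc(const float *a, const float *b, int n) {
    float acc_0 = 0.0f, acc_1 = 0.0f, acc_2 = 0.0f, acc_3 = 0.0f;
    for (int i = 0; i < n; i += 4) {
        acc_0 += a[i + 0] * b[i + 0];
        acc_1 += a[i + 1] * b[i + 1];
        acc_2 += a[i + 2] * b[i + 2];
        acc_3 += a[i + 3] * b[i + 3];
    }
    return (acc_0 + acc_1) + (acc_2 + acc_3);
}

int main(void) {
    const float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    const float b[8] = {1, 1, 1, 1, 1, 1, 1, 1};
    printf("%f %f\n", dot_1acc(a, b, 8), dot_4acc(a, b, 8));
    return 0;
}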
3ooabkhxtn 2023-05-12 09:15:46 +00:00
parent 607b9c7373
commit 78bbb3cdfe

ggml.c (121 changed lines)
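
Before the diff, it may help to spell out what this kernel computes. A scalar sketch of the q4_0 × q8_0 dot product that the SSE3 path implements (block layouts assumed to match this version of ggml.c; illustrative only, not the commit's code):

#include <stdint.h>

#define QK 32

typedef struct { float d; uint8_t qs[QK / 2]; } block_q4_0; // 32 4-bit quants, 2 per byte
typedef struct { float d; int8_t  qs[QK];     } block_q8_0; // 32 8-bit quants

// Reference: low nibbles pair with y->qs[0..15] and high nibbles with
// y->qs[16..31] (mirroring the lowMask / shift-by-4 loads below), each
// nibble is offset by 8 to recenter on zero, and the integer sum is
// scaled by the product of the two block deltas.
static float vec_dot_q4_0_q8_0_ref(int nb, const block_q4_0 *x, const block_q8_0 *y) {
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int j = 0; j < QK / 2; ++j) {
            const int v0 = (x[i].qs[j] & 0x0F) - 8;
            const int v1 = (x[i].qs[j] >> 4)   - 8;
            sumi += v0 * y[i].qs[j] + v1 * y[i].qs[j + QK / 2];
        }
        sumf += sumi * x[i].d * y[i].d;
    }
    return sumf;
}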

@@ -501,6 +501,17 @@ static inline float hsum_float_2x4(const __m128 x, const __m128 y) {
     return _mm_cvtss_f32(res);
 }
 
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = _mm_hadd_ps(a, b);
+    __m128 res_1 = _mm_hadd_ps(c, d);
+    __m128 res = _mm_hadd_ps(res_0, res_1);
+    res = _mm_hadd_ps(res, res);
+    res = _mm_hadd_ps(res, res);
+    return _mm_cvtss_f32(res);
+}
+
 #endif
 
 #if __AVX__ || __AVX2__ || __AVX512F__
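
The new helper reduces four __m128 vectors, i.e. 16 floats, to one scalar through a cascade of SSE3 horizontal adds. A scalar model of the result (illustrative only):

// Equivalent of hsum_float_4x4(a, b, c, d): the sum of all 16 floats,
// with v[0..3] standing for a, b, c, d viewed as four floats each.
static float hsum_float_4x4_ref(const float v[4][4]) {
    float sum = 0.0f;
    for (int i = 0; i < 4; ++i) {
        for (int j = 0; j < 4; ++j) {
            sum += v[i][j];
        }
    }
    return sum;
}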
@@ -2149,44 +2160,112 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     *s = hsum_float_8(acc);
 #elif defined(__SSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
     // Initialize accumulator with zeros
     __m128 acc_0 = _mm_setzero_ps();
     __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
 
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m128 d = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
-
-        const __m128i lowMask = _mm_set1_epi8(0xF);
-        const __m128i off = _mm_set1_epi8(8);
-
-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
-
-        __m128i bx = _mm_and_si128(lowMask, tmp);
-        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
-        bx = _mm_sub_epi8(bx, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
-
-        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
-        bx = _mm_sub_epi8(bx, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
-
-        // Convert int32_t to float
-        __m128 p0 = _mm_cvtepi32_ps(i32_0);
-        __m128 p1 = _mm_cvtepi32_ps(i32_1);
-
-        // Apply the scale
-        __m128 p0_d = _mm_mul_ps( d, p0 );
-        __m128 p1_d = _mm_mul_ps( d, p1 );
-
-        // Accumulate
-        acc_0 = _mm_add_ps(p0_d, acc_0);
-        acc_1 = _mm_add_ps(p1_d, acc_1);
-    }
-
-    *s = hsum_float_2x4(acc_0, acc_1);
+    // First pass over blocks 0 and 1: the scaled products initialize
+    // the accumulators directly (multiply only, no add)
+    {
+        // Compute combined scale for the blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        // Compute combined scale for the blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
+
+    // Main loop
+    for (int i = 2; i < nb; i += 2) {
+        // Compute combined scale for the blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        // Compute combined scale for the blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #else
     // scalar
     float sumf = 0.0;