From bf5261e04e264c4010a86d2d2d6d9a67202fe59e Mon Sep 17 00:00:00 2001
From: Chris Elrod
Date: Sun, 26 May 2024 21:07:13 -0400
Subject: [PATCH] faster avx512 exp implementation

---
 ggml.c | 68 +++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 39 insertions(+), 29 deletions(-)

diff --git a/ggml.c b/ggml.c
index 5145ceec9..6c97dfbe9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2303,35 +2303,45 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
 // numbers above 88.38 will flush to infinity
 // numbers beneath -103.97 will flush to zero
 inline static __m512 ggml_v_expf(__m512 x) {
-  const __m512 r = _mm512_set1_ps(0x1.8p23f);
-  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
-  const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
-                                    _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
-  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
-  const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
-  const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                                                   _mm512_set1_ps(0x1.573e2ep-5f)), u,
-                                                   _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                                                   _mm512_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
-  if (_mm512_kortestz(c, c))
-    return _mm512_fmadd_ps(j, k, k);
-  const __m512i g = _mm512_and_si512(
-      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
-      _mm512_set1_epi32(0x82000000u));
-  const __m512 s1 =
-      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
-  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
-  const __mmask16 d =
-      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-  return _mm512_mask_blend_ps(
-      d, _mm512_mask_blend_ps(
-          c, _mm512_fmadd_ps(k, j, k),
-          _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
-      _mm512_mul_ps(s1, s1));
+    // Computes e^x for each of the 16 packed floats in `x` via range reduction:
+    // x = intc * log(2) + r, so e^x = 2^intc * e^r.
+    // Adding and then subtracting this constant rounds a `float` to the nearest
+    // integer: floats are 1 apart at this magnitude.
+    const __m512 round = _mm512_set1_ps(0x1.8p+23f);
+    // intc = round(log_2(e) * x); integer-valued due to +/- `round`
+    const __m512 intc = _mm512_sub_ps(
+        _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), round), round);
+    // r = x - intc / log_2(e) = x - intc * log(2)
+    // To attain the needed accuracy, -log(2) is split into hi and lo parts and
+    // applied as fma(lo, intc, fma(hi, intc, x)).
+    // r is approximately in the range [-log(2)/2, log(2)/2).
+    const __m512 r = _mm512_fmadd_ps(_mm512_set1_ps(0x1.05c61p-29f), intc,
+        _mm512_fmadd_ps(_mm512_set1_ps(-0x1.62e43p-1f), intc, x));
+
+    // With the reduced range, a degree-7 minimax polynomial in `r`, evaluated in
+    // Horner form, approximates e^r.
+    __m512 expr = _mm512_fmadd_ps(_mm512_set1_ps(0x1.a1d714p-13f), r,
+                                  _mm512_set1_ps(0x1.6dd982p-10f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1.126facp-7f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1.55541cp-5f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1.555404p-3f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1p-1f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1p+0f));
+    expr = _mm512_fmadd_ps(expr, r, _mm512_set1_ps(0x1p+0f));
+    // exp(x) = exp(r + intc * log(2)) = exp(r) * 2^intc
+    const __m512 res = _mm512_scalef_ps(expr, intc);
+
+    // When |intc| > 192 the result is exactly 0 or +inf in single precision, and
+    // the range reduction above is unreliable (e.g. x = -inf gives r = NaN), so
+    // those lanes are saturated explicitly instead of using `res`.
+    const __mmask16 sat = _mm512_cmp_ps_mask(_mm512_abs_ps(intc), _mm512_set1_ps(192), _CMP_GT_OQ);
+    if (_mm512_kortestz(sat, sat)) {
+        return res;
+    }
+    const __m512 inf = _mm512_castsi512_ps(_mm512_set1_epi32(0x7f800000));
+    const __mmask16 neg = _mm512_cmp_ps_mask(intc, _mm512_setzero_ps(), _CMP_LE_OQ);
+    const __m512 lim = _mm512_mask_blend_ps(neg, inf, _mm512_setzero_ps());
+    return _mm512_mask_blend_ps(sat, res, lim);
 }
 
 // computes silu x/(1+exp(-x)) in single precision vector