From dc18c34bcaf6f114af166384b8b5c8110c4e1429 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 27 May 2024 11:12:37 -0400 Subject: [PATCH] improve accuracy, handle special cases --- ggml.c | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/ggml.c b/ggml.c index 19bfe331a..22427a422 100644 --- a/ggml.c +++ b/ggml.c @@ -2303,25 +2303,31 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) { // numbers above 88.38 will flush to infinity // numbers beneath -103.97 will flush to zero inline static __m512 ggml_v_expf(__m512 x) { - // large constant that rounds `float`s; floats are 1 apart at this magnitude - const __m512 round = _mm512_set1_ps(0x1.8p+23); - // these `float`s are integer-valued due to +/- `round` - // intc = round(log_2(e) * x) - const __m512 intc =_mm512_sub_ps(fma(x, _mm512_set1_ps(0x1.715476p+0), round), round); - // r = x - intc / log_2(e) - // To attain the needed accuracy, we do this using both the hi and low parts - // of `(-1 / log_2(e))` = -log(2), i.e. fma(lo, intc, fma(hi, intc, x)); - // r is in the range [-lcg(2)/2, log(2)/2). - const __m512 r = _mm512_fmadd_ps(_mm512_set1_ps(0x1.05c61p-29), intc, - _mm512_fmadd_ps(_mm512_set1_ps(-0x1.62e43p-1), intc, x)); - - // with the reduced range, we can use a min-maxed polynomial to calculate `e^r`. - const __m512 expr =_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps( - _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(0x1.a1d714p-13, r, 0x1.6dd982p-10), - r, 0x1.126facp-7), r, 0x1.55541cp-5), r, 0x1.555404p-3), r, 0x1p-1), r, 0x1p+0), - r, 0x1p+0); - // exp(x) = exp(r + intc/log_2(e)) = exp(r)*exp(intc*log(2)) = exp(r)*pow(2,intc) - return _mm512_scalef_ps(expr, intc); + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 jplus1 = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 kjk = _mm512_scalef_ps(jplus1, n); + if (_mm512_kortestz(d, d)) + return kjk; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, kjk, alt); } // computes silu x/(1+exp(-x)) in single precision vector