improve accuracy, handle special cases

This commit is contained in:
Chris Elrod 2024-05-27 11:12:37 -04:00
parent ea6e19cd40
commit dc18c34bca
No known key found for this signature in database
GPG key ID: 49A2CC0D19080EA9

44
ggml.c
View file

@@ -2303,25 +2303,31 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
// numbers above 88.38 will flush to infinity // numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero // numbers beneath -103.97 will flush to zero
inline static __m512 ggml_v_expf(__m512 x) { inline static __m512 ggml_v_expf(__m512 x) {
// large constant that rounds `float`s; floats are 1 apart at this magnitude const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 round = _mm512_set1_ps(0x1.8p+23); const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
// these `float`s are integer-valued due to +/- `round` const __m512 n = _mm512_sub_ps(z, r);
// intc = round(log_2(e) * x) const __m512 b =
const __m512 intc =_mm512_sub_ps(fma(x, _mm512_set1_ps(0x1.715476p+0), round), round); _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
// r = x - intc / log_2(e) _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
// To attain the needed accuracy, we do this using both the hi and low parts const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
// of `(-1 / log_2(e))` = -log(2), i.e. fma(lo, intc, fma(hi, intc, x)); const __mmask16 d =
// r is in the range [-log(2)/2, log(2)/2). _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
const __m512 r = _mm512_fmadd_ps(_mm512_set1_ps(0x1.05c61p-29), intc, const __m512 u = _mm512_mul_ps(b, b);
_mm512_fmadd_ps(_mm512_set1_ps(-0x1.62e43p-1), intc, x)); const __m512 jplus1 = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
// with the reduced range, we can use a min-maxed polynomial to calculate `e^r`. _mm512_set1_ps(0x1.573e2ep-5f)),
const __m512 expr =_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps( u,
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(0x1.a1d714p-13, r, 0x1.6dd982p-10), _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
r, 0x1.126facp-7), r, 0x1.55541cp-5), r, 0x1.555404p-3), r, 0x1p-1), r, 0x1p+0), _mm512_set1_ps(0x1.fffdb6p-2f))),
r, 0x1p+0); u,
// exp(x) = exp(r + intc/log_2(e)) = exp(r)*exp(intc*log(2)) = exp(r)*pow(2,intc) _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
return _mm512_scalef_ps(expr, intc); const __m512 kjk = _mm512_scalef_ps(jplus1, n);
if (_mm512_kortestz(d, d))
return kjk;
const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, kjk, alt);
} }
// computes silu x/(1+exp(-x)) in single precision vector // computes silu x/(1+exp(-x)) in single precision vector