improve accuracy, handle special cases
This commit is contained in:
parent
ea6e19cd40
commit
dc18c34bca
1 changed files with 25 additions and 19 deletions
44
ggml.c
44
ggml.c
|
@ -2303,25 +2303,31 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
|
||||||
// numbers above 88.38 will flush to infinity
|
// numbers above 88.38 will flush to infinity
|
||||||
// numbers beneath -103.97 will flush to zero
|
// numbers beneath -103.97 will flush to zero
|
||||||
inline static __m512 ggml_v_expf(__m512 x) {
|
inline static __m512 ggml_v_expf(__m512 x) {
|
||||||
// large constant that rounds `float`s; floats are 1 apart at this magnitude
|
const __m512 r = _mm512_set1_ps(0x1.8p23f);
|
||||||
const __m512 round = _mm512_set1_ps(0x1.8p+23);
|
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
|
||||||
// these `float`s are integer-valued due to +/- `round`
|
const __m512 n = _mm512_sub_ps(z, r);
|
||||||
// intc = round(log_2(e) * x)
|
const __m512 b =
|
||||||
const __m512 intc =_mm512_sub_ps(fma(x, _mm512_set1_ps(0x1.715476p+0), round), round);
|
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
|
||||||
// r = x - intc / log_2(e)
|
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
|
||||||
// To attain the needed accuracy, we do this using both the hi and low parts
|
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
|
||||||
// of `(-1 / log_2(e))` = -log(2), i.e. fma(lo, intc, fma(hi, intc, x));
|
const __mmask16 d =
|
||||||
// r is in the range [-lcg(2)/2, log(2)/2).
|
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
|
||||||
const __m512 r = _mm512_fmadd_ps(_mm512_set1_ps(0x1.05c61p-29), intc,
|
const __m512 u = _mm512_mul_ps(b, b);
|
||||||
_mm512_fmadd_ps(_mm512_set1_ps(-0x1.62e43p-1), intc, x));
|
const __m512 jplus1 = _mm512_fmadd_ps(
|
||||||
|
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
|
||||||
// with the reduced range, we can use a min-maxed polynomial to calculate `e^r`.
|
_mm512_set1_ps(0x1.573e2ep-5f)),
|
||||||
const __m512 expr =_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(
|
u,
|
||||||
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(0x1.a1d714p-13, r, 0x1.6dd982p-10),
|
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
|
||||||
r, 0x1.126facp-7), r, 0x1.55541cp-5), r, 0x1.555404p-3), r, 0x1p-1), r, 0x1p+0),
|
_mm512_set1_ps(0x1.fffdb6p-2f))),
|
||||||
r, 0x1p+0);
|
u,
|
||||||
// exp(x) = exp(r + intc/log_2(e)) = exp(r)*exp(intc*log(2)) = exp(r)*pow(2,intc)
|
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
|
||||||
return _mm512_scalef_ps(expr, intc);
|
const __m512 kjk = _mm512_scalef_ps(jplus1, n);
|
||||||
|
if (_mm512_kortestz(d, d))
|
||||||
|
return kjk;
|
||||||
|
const __m512 zero = _mm512_setzero_ps();
|
||||||
|
const __m512 alt = _mm512_mask_blend_ps(
|
||||||
|
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
|
||||||
|
return _mm512_mask_blend_ps(d, kjk, alt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// computes silu x/(1+exp(-x)) in single precision vector
|
// computes silu x/(1+exp(-x)) in single precision vector
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue