faster avx512 exp implementation
This commit is contained in:
parent
dff451cfa1
commit
bf5261e04e
1 changed files with 19 additions and 29 deletions
48
ggml.c
48
ggml.c
|
@ -2303,35 +2303,25 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
|
||||||
// numbers above 88.38 will flush to infinity
|
// numbers above 88.38 will flush to infinity
|
||||||
// numbers beneath -103.97 will flush to zero
|
// numbers beneath -103.97 will flush to zero
|
||||||
inline static __m512 ggml_v_expf(__m512 x) {
|
inline static __m512 ggml_v_expf(__m512 x) {
|
||||||
const __m512 r = _mm512_set1_ps(0x1.8p23f);
|
// large constant that rounds `float`s; floats are 1 apart at this magnitude
|
||||||
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
|
const __m512 round = _mm512_set1_ps(0x1.8p+23);
|
||||||
const __m512 n = _mm512_sub_ps(z, r);
|
// these `float`s are integer-valued due to +/- `round`
|
||||||
const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
|
// intc = round(log_2(e) * x)
|
||||||
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
|
const __m512 intc =_mm512_sub_ps(fma(x, _mm512_set1_ps(0x1.715476p+0), round), round);
|
||||||
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
|
// r = x - intc / log_2(e)
|
||||||
const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
|
// To attain the needed accuracy, we do this using both the hi and low parts
|
||||||
const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
|
// of `(-1 / log_2(e))` = -log(2), i.e. fma(lo, intc, fma(hi, intc, x));
|
||||||
const __m512 u = _mm512_mul_ps(b, b);
|
// r is in the range [-lcg(2)/2, log(2)/2).
|
||||||
const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
|
const __m512 r = _mm512_fmadd_ps(_mm512_set1_ps(0x1.05c61p-29), intc,
|
||||||
_mm512_set1_ps(0x1.573e2ep-5f)), u,
|
_mm512_fmadd_ps(_mm512_set1_ps(-0x1.62e43p-1), intc, x));
|
||||||
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
|
|
||||||
_mm512_set1_ps(0x1.fffdb6p-2f))),
|
// with the reduced range, we can use a min-maxed polynomial to calculate `e^r`.
|
||||||
u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
|
const __m512 expr =_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(
|
||||||
if (_mm512_kortestz(c, c))
|
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(0x1.a1d714p-13, x, 0x1.6dd982p-10),
|
||||||
return _mm512_fmadd_ps(j, k, k);
|
x, 0x1.126facp-7), x, 0x1.55541cp-5), x, 0x1.555404p-3), x, 0x1p-1), x, 0x1p+0),
|
||||||
const __m512i g = _mm512_and_si512(
|
x, 0x1p+0);
|
||||||
_mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
|
// exp(x) = exp(r + intc/log_2(e)) = exp(r)*exp(intc*log(2)) = exp(r)*pow(2,intc)
|
||||||
_mm512_set1_epi32(0x82000000u));
|
return _mm512_scalef_ps(expr, intc);
|
||||||
const __m512 s1 =
|
|
||||||
_mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
|
|
||||||
const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
|
|
||||||
const __mmask16 d =
|
|
||||||
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
|
|
||||||
return _mm512_mask_blend_ps(
|
|
||||||
d, _mm512_mask_blend_ps(
|
|
||||||
c, _mm512_fmadd_ps(k, j, k),
|
|
||||||
_mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
|
|
||||||
_mm512_mul_ps(s1, s1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// computes silu x/(1+exp(-x)) in single precision vector
|
// computes silu x/(1+exp(-x)) in single precision vector
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue