This commit is contained in:
Chris Elrod 2024-05-27 11:29:03 -04:00
parent dc18c34bca
commit a572666e4e
No known key found for this signature in database
GPG key ID: 49A2CC0D19080EA9

9
ggml.c
View file

@ -2309,11 +2309,10 @@ inline static __m512 ggml_v_expf(__m512 x) {
const __m512 b = const __m512 b =
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
const __mmask16 d = const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b); const __m512 u = _mm512_mul_ps(b, b);
const __m512 jplus1 = _mm512_fmadd_ps( const __m512 j = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)), _mm512_set1_ps(0x1.573e2ep-5f)),
u, u,
@ -2321,13 +2320,13 @@ inline static __m512 ggml_v_expf(__m512 x) {
_mm512_set1_ps(0x1.fffdb6p-2f))), _mm512_set1_ps(0x1.fffdb6p-2f))),
u, u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
const __m512 kjk = _mm512_scalef_ps(jplus1, n); const __m512 res = _mm512_scalef_ps(j, n);
if (_mm512_kortestz(d, d)) if (_mm512_kortestz(d, d))
return kjk; return res;
const __m512 zero = _mm512_setzero_ps(); const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps( const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, kjk, alt); return _mm512_mask_blend_ps(d, res, alt);
} }
// computes silu x/(1+exp(-x)) in single precision vector // computes silu x/(1+exp(-x)) in single precision vector