Implement ggml_v_expf() with a fast approximation on AVX/AVX2/AVX512
The code implements a fast, vectorized approximation to exp(x). The approximation trick should work on most systems, but is implemented only for AVX/AVX2/AVX512 to start. Constants used are optimized for explainability rather than minimizing average relative error. The unvectorized implementation looks like: float exp_fast(float x) { int32_t i = 12102203.2f*x+0x3f800000; return *(float *)&i; } The result is accurate to std::exp() within 10% average relative error. Explanation: We know that e^x = 2^(x * log2(e)) The code works by taking advantage of the approximate 2^x computed during float->integer conversions. Doing this correctly requires adjusting the bias with exponent+127 and multiplying by a log2(e) factor to compute e^x instead of 2^x. This allows all of the scaling to happen with a single FMA. log2e constant ~~ 2^23 * log2(e) bias constant ~~ 127.0 * 2^23 The resulting approximation is valid over the +- 88 domain of the exp() function. In order to replicate the behavior of the existing implementation, the commit clamps values outside the domain. This allows the commit to pass CI tests, but also runs somewhat slower than the unclamped implementation.
This commit is contained in:
parent
8841ce3f43
commit
f1fc512752
1 changed files with 33 additions and 99 deletions
132
ggml/src/ggml.c
132
ggml/src/ggml.c
|
@ -2684,35 +2684,17 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
|
|||
|
||||
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
inline static __m512 ggml_v_expf(__m512 x) {
|
||||
const __m512 r = _mm512_set1_ps(0x1.8p23f);
|
||||
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
|
||||
const __m512 n = _mm512_sub_ps(z, r);
|
||||
const __m512 b =
|
||||
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
|
||||
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
|
||||
const __mmask16 d =
|
||||
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
|
||||
const __m512 u = _mm512_mul_ps(b, b);
|
||||
const __m512 j = _mm512_fmadd_ps(
|
||||
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
|
||||
_mm512_set1_ps(0x1.573e2ep-5f)),
|
||||
u,
|
||||
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
|
||||
_mm512_set1_ps(0x1.fffdb6p-2f))),
|
||||
u,
|
||||
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
|
||||
const __m512 res = _mm512_scalef_ps(j, n);
|
||||
if (_mm512_kortestz(d, d))
|
||||
return res;
|
||||
const __m512 zero = _mm512_setzero_ps();
|
||||
const __m512 alt = _mm512_mask_blend_ps(
|
||||
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
|
||||
return _mm512_mask_blend_ps(d, res, alt);
|
||||
const __m512 h_lim = _mm512_set1_ps(88.0f);
|
||||
const __m512 l_lim = _mm512_set1_ps(-88.0f);
|
||||
|
||||
const __m512 log2e = _mm512_set1_ps(12102203.2f);
|
||||
const __m512 bias = _mm512_set1_ps(0x3f800000);
|
||||
|
||||
x = _mm512_max_ps(x, l_lim);
|
||||
x = _mm512_min_ps(x, h_lim);
|
||||
|
||||
return _mm512_castsi512_ps(_mm512_cvttps_epi32(_mm512_fmadd_ps(log2e, x, bias)));
|
||||
}
|
||||
|
||||
// computes silu x/(1+exp(-x)) in single precision vector
|
||||
|
@ -2727,47 +2709,18 @@ inline static __m512 ggml_v_silu(__m512 x) {
|
|||
|
||||
#elif defined(__AVX2__) && defined(__FMA__)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
inline static __m256 ggml_v_expf(__m256 x) {
|
||||
const __m256 r = _mm256_set1_ps(0x1.8p23f);
|
||||
const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
|
||||
const __m256 n = _mm256_sub_ps(z, r);
|
||||
const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
|
||||
_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
|
||||
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
|
||||
const __m256 k = _mm256_castsi256_ps(
|
||||
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
|
||||
const __m256i c = _mm256_castps_si256(
|
||||
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
|
||||
_mm256_set1_ps(126), _CMP_GT_OQ));
|
||||
const __m256 u = _mm256_mul_ps(b, b);
|
||||
const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
|
||||
_mm256_set1_ps(0x1.573e2ep-5f)), u,
|
||||
_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
|
||||
_mm256_set1_ps(0x1.fffdb6p-2f))),
|
||||
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
|
||||
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
|
||||
return _mm256_fmadd_ps(j, k, k);
|
||||
const __m256i g = _mm256_and_si256(
|
||||
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
|
||||
_mm256_set1_epi32(0x82000000u));
|
||||
const __m256 s1 =
|
||||
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
|
||||
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
|
||||
const __m256i d = _mm256_castps_si256(
|
||||
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
|
||||
_mm256_set1_ps(192), _CMP_GT_OQ));
|
||||
return _mm256_or_ps(
|
||||
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
|
||||
_mm256_andnot_ps(
|
||||
_mm256_castsi256_ps(d),
|
||||
_mm256_or_ps(
|
||||
_mm256_and_ps(_mm256_castsi256_ps(c),
|
||||
_mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
|
||||
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
|
||||
const __m256 h_lim = _mm256_set1_ps(88.0f);
|
||||
const __m256 l_lim = _mm256_set1_ps(-88.0f);
|
||||
|
||||
const __m256 log2e = _mm256_set1_ps(12102203.2f);
|
||||
const __m256 bias = _mm256_set1_ps(0x3f800000);
|
||||
|
||||
x = _mm256_max_ps(x, l_lim);
|
||||
x = _mm256_min_ps(x, h_lim);
|
||||
|
||||
return _mm256_castsi256_ps(_mm256_cvttps_epi32(_mm256_fmadd_ps(log2e, x, bias)));
|
||||
return x;
|
||||
}
|
||||
|
||||
// computes silu x/(1+exp(-x)) in single precision vector
|
||||
|
@ -2790,38 +2743,19 @@ inline static __m256 ggml_v_silu(__m256 x) {
|
|||
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
|
||||
#endif
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
inline static __m128 ggml_v_expf(__m128 x) {
|
||||
const __m128 r = _mm_set1_ps(0x1.8p23f);
|
||||
const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
|
||||
const __m128 n = _mm_sub_ps(z, r);
|
||||
const __m128 b =
|
||||
NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
|
||||
const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
|
||||
const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
|
||||
const __m128i c =
|
||||
_mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
|
||||
const __m128 u = _mm_mul_ps(b, b);
|
||||
const __m128 j =
|
||||
MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
|
||||
MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
|
||||
u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
|
||||
if (!_mm_movemask_epi8(c))
|
||||
return MADD128(j, k, k);
|
||||
const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
|
||||
_mm_set1_epi32(0x82000000u));
|
||||
const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
|
||||
const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
|
||||
const __m128i d =
|
||||
_mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
|
||||
return _mm_or_ps(
|
||||
_mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
|
||||
_mm_andnot_ps(_mm_castsi128_ps(d),
|
||||
_mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
|
||||
_mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
|
||||
const __m128 h_lim = _mm_set1_ps(88.0f);
|
||||
const __m128 l_lim = _mm_set1_ps(-88.0f);
|
||||
|
||||
const __m128 log2e = _mm_set1_ps(12102203.2f); // log(2) * 2^23
|
||||
const __m128 bias = _mm_set1_ps(0x3f800000); // 127 * 2^23
|
||||
|
||||
// Clamping the input range is cheaper than implementing the
|
||||
// out-of-domain behavior of the original implementation
|
||||
x = _mm_max_ps(x, l_lim);
|
||||
x = _mm_min_ps(x, h_lim);
|
||||
|
||||
return _mm_castsi128_ps(_mm_cvttps_epi32(_mm_fmadd_ps(log2e, x, bias)));
|
||||
}
|
||||
|
||||
// computes silu x/(1+exp(-x)) in single precision vector
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue