Implement ggml_v_expf() with a fast approximation on AVX/AVX2/AVX512

The code implements a fast, vectorized approximation to exp(x). The
approximation trick should work on most systems, but is implemented
only for AVX/AVX2/AVX512 to start.

The constants used are chosen for explainability rather than for
minimizing average relative error.

The unvectorized implementation looks like:

float exp_fast(float x) {
    int32_t i = 12102203.2f*x+0x3f800000;   // i ~~ 2^23 * (x*log2(e) + 127)
    return *(float *)&i;                    // reinterpret the bits as a float
}

The result agrees with std::exp() to within 10% average relative error.
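
That figure can be sanity-checked with a small standalone harness
(hypothetical, not part of the commit; the memcpy bit cast and the
[-80, 80] sampling range are assumptions):

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Same trick as above, with memcpy for the bit reinterpretation. */
static float exp_fast(float x) {
    int32_t i = (int32_t)(12102203.2f * x + 0x3f800000);
    float f;
    memcpy(&f, &i, sizeof f);
    return f;
}

int main(void) {
    double sum = 0.0;
    int n = 0;
    for (float x = -80.0f; x <= 80.0f; x += 0.01f, n++) {
        float ref = expf(x);
        sum += fabsf(exp_fast(x) - ref) / ref;
    }
    printf("average relative error: %.4f\n", sum / n);
    return 0;
}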

Explanation:

We know that
e^x = 2^(x * log2(e))

The code works by exploiting the layout of IEEE 754 floats: constructing
the integer bit pattern directly from x and reinterpreting it as a float
yields an approximate 2^x. Doing this correctly requires adding the
exponent bias of 127 and multiplying by a log2(e) factor to compute e^x
instead of 2^x, with both constants pre-scaled by 2^23 so they land in
the exponent field. This allows all of the scaling to happen with a
single FMA.

log2e constant ~~ 2^23 * log2(e) = 12102203.16
bias constant  =  127 * 2^23     = 1065353216 = 0x3f800000
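
For example, at x = 1 the FMA computes 12102203.2 * 1 + 1065353216, and the
resulting bit pattern decodes to 2^1 * 1.4427 ~~ 2.885, versus e^1 ~~ 2.718.
That is about 6% high, which is the worst case of approximating 2^f by 1+f
in the mantissa.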

The resulting approximation is valid over roughly the +-88 input domain
of single-precision exp(). In order to replicate the behavior of the
existing implementation, the commit clamps inputs outside that domain.
This allows the commit to pass CI tests, but makes it somewhat slower
than the unclamped implementation.
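
In scalar terms the clamped variant looks roughly like this (a sketch
only; the function name and the memcpy bit cast are illustrative, not
the committed code):

#include <stdint.h>
#include <string.h>

static float exp_fast_clamped(float x) {
    // clamp to the valid domain, mirroring the max_ps/min_ps pair in the vector code
    if (x < -88.0f) x = -88.0f;
    if (x >  88.0f) x =  88.0f;
    int32_t i = (int32_t)(12102203.2f * x + 0x3f800000);
    float f;
    memcpy(&f, &i, sizeof f);
    return f;
}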
J M 2024-10-27 12:34:32 -07:00
parent 8841ce3f43
commit f1fc512752


@@ -2684,35 +2684,17 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m512 ggml_v_expf(__m512 x) {
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b =
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
const __m512 res = _mm512_scalef_ps(j, n);
if (_mm512_kortestz(d, d))
return res;
const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, res, alt);
const __m512 h_lim = _mm512_set1_ps(88.0f);
const __m512 l_lim = _mm512_set1_ps(-88.0f);
const __m512 log2e = _mm512_set1_ps(12102203.2f); // log2(e) * 2^23
const __m512 bias = _mm512_set1_ps(0x3f800000); // 127 * 2^23
x = _mm512_max_ps(x, l_lim);
x = _mm512_min_ps(x, h_lim);
return _mm512_castsi512_ps(_mm512_cvttps_epi32(_mm512_fmadd_ps(log2e, x, bias)));
}
// computes silu x/(1+exp(-x)) in single precision vector
@@ -2727,47 +2709,18 @@ inline static __m512 ggml_v_silu(__m512 x) {
#elif defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m256 ggml_v_expf(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps(z, r);
const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps(b, b);
const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)), u,
_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return _mm256_fmadd_ps(j, k, k);
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
const __m256 h_lim = _mm256_set1_ps(88.0f);
const __m256 l_lim = _mm256_set1_ps(-88.0f);
const __m256 log2e = _mm256_set1_ps(12102203.2f); // log2(e) * 2^23
const __m256 bias = _mm256_set1_ps(0x3f800000); // 127 * 2^23
x = _mm256_max_ps(x, l_lim);
x = _mm256_min_ps(x, h_lim);
return _mm256_castsi256_ps(_mm256_cvttps_epi32(_mm256_fmadd_ps(log2e, x, bias)));
}
// computes silu x/(1+exp(-x)) in single precision vector
@@ -2790,38 +2743,19 @@ inline static __m256 ggml_v_silu(__m256 x) {
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
#endif
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static __m128 ggml_v_expf(__m128 x) {
const __m128 r = _mm_set1_ps(0x1.8p23f);
const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
const __m128 n = _mm_sub_ps(z, r);
const __m128 b =
NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
const __m128i c =
_mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
const __m128 u = _mm_mul_ps(b, b);
const __m128 j =
MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
if (!_mm_movemask_epi8(c))
return MADD128(j, k, k);
const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
_mm_set1_epi32(0x82000000u));
const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
const __m128i d =
_mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
return _mm_or_ps(
_mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
_mm_andnot_ps(_mm_castsi128_ps(d),
_mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
_mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
const __m128 h_lim = _mm_set1_ps(88.0f);
const __m128 l_lim = _mm_set1_ps(-88.0f);
const __m128 log2e = _mm_set1_ps(12102203.2f); // log2(e) * 2^23
const __m128 bias = _mm_set1_ps(0x3f800000); // 127 * 2^23
// Clamping the input range is cheaper than implementing the
// out-of-domain behavior of the original implementation
x = _mm_max_ps(x, l_lim);
x = _mm_min_ps(x, h_lim);
return _mm_castsi128_ps(_mm_cvttps_epi32(MADD128(log2e, x, bias)));
}
// computes silu x/(1+exp(-x)) in single precision vector