From 84e8da665d3982aede690a6c7244ff4b37ee5d7d Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Thu, 24 Aug 2023 15:13:18 +0200 Subject: [PATCH] ggml.c : use ggml_float for gelu --- ggml.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml.c b/ggml.c index 7d40cf815..15206ea36 100644 --- a/ggml.c +++ b/ggml.c @@ -3554,12 +3554,13 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float GELU_QUICK_COEF = -1.702f; +static const ggml_float GELU_COEF_A = 0.044715; +static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; inline static float ggml_gelu_f32(float x) { - return 0.5*(double)x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + const ggml_float xx = (ggml_float) x; + return 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx))); } inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {