ggml.c : use ggml_float for gelu

This commit is contained in:
klosax 2023-08-24 15:13:18 +02:00 committed by GitHub
parent 797312e758
commit 84e8da665d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

9
ggml.c
View file

@ -3554,12 +3554,13 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static const float GELU_QUICK_COEF = -1.702f;
static const ggml_float GELU_COEF_A = 0.044715;
static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
inline static float ggml_gelu_f32(float x) {
return 0.5*(double)x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
const ggml_float xx = (ggml_float) x;
return 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx)));
}
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {