From 0a85ae73977e69b0fbeb9345361ce8e358919416 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 15:04:53 +0300 Subject: [PATCH] metal : fix GELU kernel numerical stability by using precise::tanh --- ggml-metal.m | 4 ++-- ggml-metal.metal | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 1c1810da4..969cf7daa 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -539,8 +539,8 @@ void ggml_metal_graph_compute( id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); for (int ind = node_start; ind < node_end; ++ind) { const int i = has_concur ? ctx->concur_list[ind] : ind; diff --git a/ggml-metal.metal b/ggml-metal.metal index 53604a250..7bc3fdf37 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -87,7 +87,12 @@ kernel void kernel_gelu( device float * dst, uint tpig[[thread_position_in_grid]]) { float x = src0[tpig]; - dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } kernel void kernel_soft_max(