From ecdc16163efa41fc41ac2dfca63cb7af60e2362c Mon Sep 17 00:00:00 2001
From: xaedes
Date: Fri, 28 Jul 2023 23:09:56 +0200
Subject: [PATCH] ggml : update ggml_rms_norm_back with configurable eps

---
 ggml.c | 13 ++++++++++---
 ggml.h |  4 ++--
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/ggml.c b/ggml.c
index 4ddd154bf..756000cff 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5824,7 +5824,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
 struct ggml_tensor * ggml_rms_norm_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
+        struct ggml_tensor  * b,
+        float eps) {
     bool is_node = false;
 
     if (a->grad) {
@@ -5834,6 +5835,8 @@ struct ggml_tensor * ggml_rms_norm_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
     result->op   = GGML_OP_RMS_NORM_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -10211,7 +10214,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -15029,9 +15033,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
+                    float eps;
+                    memcpy(&eps, tensor->op_params, sizeof(float));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rms_norm_back(ctx, src0, tensor->grad),
+                            ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
                             inplace);
                 }
             } break;
diff --git a/ggml.h b/ggml.h
index 3980c0050..9e8ed956e 100644
--- a/ggml.h
+++ b/ggml.h
@@ -894,11 +894,11 @@ extern "C" {
 
     // a - x
     // b - dy
-    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float eps);
 
     // A: n columns, m rows
     // B: n columns, p rows  (i.e. we transpose it internally)
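
Not part of the patch: a minimal sketch of how a caller could pass the now-explicit epsilon after this change. Only the ggml_rms_norm_back signature comes from the diff above; the helper name, tensor names, and the eps value are illustrative assumptions.

#include "ggml.h"

// Sketch only. Assumes ctx is an initialized ggml_context, x is the forward
// input of the RMS norm (the "a - x" argument) and dy is the gradient flowing
// back into its output (the "b - dy" argument).
static struct ggml_tensor * rms_norm_back_example(struct ggml_context * ctx,
                                                  struct ggml_tensor  * x,
                                                  struct ggml_tensor  * dy) {
    const float eps = 1e-6f; // previously hard-coded inside the op, now chosen by the caller
    return ggml_rms_norm_back(ctx, x, dy, eps);
}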