ggml : update ggml_rms_norm_back with configurable eps

This commit is contained in:
xaedes 2023-07-28 23:09:56 +02:00
parent 87035b96f7
commit ecdc16163e
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 12 additions and 5 deletions

13
ggml.c
View file

@ -5824,7 +5824,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
struct ggml_tensor * b,
float eps) {
bool is_node = false;
if (a->grad) {
@ -5834,6 +5835,8 @@ struct ggml_tensor * ggml_rms_norm_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_RMS_NORM_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
@ -10211,7 +10214,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
GGML_TENSOR_BINARY_OP_LOCALS;
const float eps = 1e-6f; // TODO: make this a parameter
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -15029,9 +15033,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rms_norm_back(ctx, src0, tensor->grad),
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
inplace);
}
} break;

4
ggml.h
View file

@ -894,11 +894,11 @@ extern "C" {
// a - x
// b - dy
// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float eps);
// A: n columns, m rows
// B: n columns, p rows (i.e. we transpose it internally)