ggml : update ggml_rms_norm_back with configurable eps
This commit is contained in:
parent
87035b96f7
commit
ecdc16163e
2 changed files with 12 additions and 5 deletions
13
ggml.c
13
ggml.c
|
@ -5824,7 +5824,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
|
|||
struct ggml_tensor * ggml_rms_norm_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
struct ggml_tensor * b,
|
||||
float eps) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
|
@ -5834,6 +5835,8 @@ struct ggml_tensor * ggml_rms_norm_back(
|
|||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, &eps, sizeof(eps));
|
||||
|
||||
result->op = GGML_OP_RMS_NORM_BACK;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
@ -10211,7 +10214,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
|||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
const float eps = 1e-6f; // TODO: make this a parameter
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
|
@ -15029,9 +15033,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
float eps;
|
||||
memcpy(&eps, tensor->op_params, sizeof(float));
|
||||
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_rms_norm_back(ctx, src0, tensor->grad),
|
||||
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
|
||||
inplace);
|
||||
}
|
||||
} break;
|
||||
|
|
4
ggml.h
4
ggml.h
|
@ -894,11 +894,11 @@ extern "C" {
|
|||
|
||||
// a - x
|
||||
// b - dy
|
||||
// TODO: update with configurable eps
|
||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b,
|
||||
float eps);
|
||||
|
||||
// A: n columns, m rows
|
||||
// B: n columns, p rows (i.e. we transpose it internally)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue