fix backward pass for rms_norm

I would have used formulas from other frameworks, but they differed so I could not decide which is correct.
Instead it was derived here in comment using manual forward-backward automatic differention of rms_norm and simplification.
This commit is contained in:
xaedes 2023-04-30 21:34:21 +02:00
parent b18b72da00
commit 84a4b39917
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

126
ggml.c
View file

@ -5709,7 +5709,7 @@ struct ggml_tensor * ggml_rms_norm_back(
bool is_node = false;
if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
// TODO: implement backward
is_node = true;
}
@ -9224,34 +9224,126 @@ static void ggml_compute_forward_rms_norm_back_f32(
const auto i12 = i02;
const auto i13 = i03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * dy = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_float sum = 0.0;
ggml_float sum_xx = 0.0;
ggml_float sum_xdz = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
sum_xx += (ggml_float)(x[i00] * x[i00]);
sum_xdz += (ggml_float)(x[i00] * dz[i00]);
}
const float mean = sum/ne00;
const float mean_eps = sum/ne00 + eps;
const float mean = sum_xx/ne00;
const float mean_eps = sum_xx/ne00 + eps;
const float sum_eps = sum_xx + eps*ne00;
const float mean_xdz = sum_xdz/ne00;
// we could cache rms from forward pass to improve performance.
// to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
const float rms = sqrtf(mean_eps);
const float rrms = 1.0f / sqrtf(mean_eps);
const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
// rms(x) = sqrt(eps + mean(square(x))) ; scalar
// y = rms_norm(x) = x/rms(x) = x/sqrt(eps+mean(square(x))) ; vector
// dx = dy*(1/rms(x) - square(x)/(n*rms(x)**3))
{
// z = rms_norm(x)
//
// rms_norm(src0) =
// scale(
// src0,
// div(
// 1,
// sqrt(
// add(
// scale(
// sum(
// sqr(
// src0)),
// (1.0/N)),
// eps))));
// postorder:
// ## op args grad
// 00 param src0 grad[#00]
// 01 const 1
// 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03]
// 04 const 1/N
// 05 scale (#03, #04) grad[#05]
// 06 const eps
// 07 add (#05, #06) grad[#07]
// 08 sqrt (#07) grad[#08]
// 09 div (#01,#08) grad[#09]
// 10 scale (#00,#09) grad[#10]
//
// backward pass, given grad[#10]
// #10: scale
// grad[#00] += scale(grad[#10],#09)
// grad[#09] += sum(mul(grad[#10],#00))
// #09: div
// grad[#08] += neg(mul(grad[#09], div(#09,#08)))
// #08: sqrt
// grad[#07] += mul(grad[#08], div(0.5, #08))
// #07: add
// grad[#05] += grad[#07]
// #05: scale
// grad[#03] += scale(grad[#05],#04)
// #03: sum
// grad[#02] += repeat(grad[#03], #02)
// #02:
// grad[#00] += scale(mul(#00, grad[#02]), 2.0)
//
// substitute and simplify:
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
// grad[#02] = repeat(grad[#03], #02)
// grad[#02] = repeat(scale(grad[#05],#04), #02)
// grad[#02] = repeat(scale(grad[#07],#04), #02)
// grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
// grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
// grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
// grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
// grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
// a = b*c + d*e
// a = b*c*f/f + d*e*f/f
// a = (b*c*f + d*e*f)*(1/f)
// a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
// a = (b + d*e/c)*c
// b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
// a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
// a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
// a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
// a = (dz + x*div(-mean_xdz,mean_eps))*rrms
// grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
// grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
}
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
// post-order:
// dx := x
// dx := scale(dx,-mean_xdz/mean_eps)
// dx := add(dx, dz)
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// square(x)
ggml_vec_mul_f32(ne00, dx, x, x);
// -square(x)/(n*rms**3)
ggml_vec_scale_f32(ne00, dx, scale);
// 1/rms(x) - square(x)/(n*rms(x)**3)
ggml_vec_acc1_f32(ne00, dx, rrms);
// dy*(1/rms(x) - square(x)/(n*rms(x)**3))
ggml_vec_mul_f32(ne00, dx, dx, dy);
ggml_vec_cpy_f32(ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, -sum_xdz/sum_eps);
ggml_vec_acc_f32(ne00, dx, dz);
ggml_vec_scale_f32(ne00, dx, rrms);
}
}
}