fix backward pass for rms_norm
I would have used formulas from other frameworks, but they differed from one another, so I could not decide which was correct. Instead, the gradient was derived here in a comment, using manual forward-backward automatic differentiation of rms_norm followed by algebraic simplification.
This commit is contained in:
parent
b18b72da00
commit
84a4b39917
1 changed file with 109 additions and 17 deletions
126
ggml.c
126
ggml.c
|
@ -5709,7 +5709,7 @@ struct ggml_tensor * ggml_rms_norm_back(
|
|||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement backward
|
||||
// TODO: implement backward
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
|
@ -9224,34 +9224,126 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
|||
const auto i12 = i02;
|
||||
const auto i13 = i03;
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const float * dy = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
ggml_float sum_xx = 0.0;
|
||||
ggml_float sum_xdz = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
sum_xx += (ggml_float)(x[i00] * x[i00]);
|
||||
sum_xdz += (ggml_float)(x[i00] * dz[i00]);
|
||||
}
|
||||
|
||||
const float mean = sum/ne00;
|
||||
const float mean_eps = sum/ne00 + eps;
|
||||
const float mean = sum_xx/ne00;
|
||||
const float mean_eps = sum_xx/ne00 + eps;
|
||||
const float sum_eps = sum_xx + eps*ne00;
|
||||
const float mean_xdz = sum_xdz/ne00;
|
||||
// we could cache rms from forward pass to improve performance.
|
||||
// to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
|
||||
const float rms = sqrtf(mean_eps);
|
||||
const float rrms = 1.0f / sqrtf(mean_eps);
|
||||
const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
|
||||
|
||||
// rms(x) = sqrt(eps + mean(square(x))) ; scalar
|
||||
// y = rms_norm(x) = x/rms(x) = x/sqrt(eps+mean(square(x))) ; vector
|
||||
// dx = dy*(1/rms(x) - square(x)/(n*rms(x)**3))
|
||||
{
|
||||
// z = rms_norm(x)
|
||||
//
|
||||
// rms_norm(src0) =
|
||||
// scale(
|
||||
// src0,
|
||||
// div(
|
||||
// 1,
|
||||
// sqrt(
|
||||
// add(
|
||||
// scale(
|
||||
// sum(
|
||||
// sqr(
|
||||
// src0)),
|
||||
// (1.0/N)),
|
||||
// eps))));
|
||||
|
||||
// postorder:
|
||||
// ## op args grad
|
||||
// 00 param src0 grad[#00]
|
||||
// 01 const 1
|
||||
// 02 sqr (#00) grad[#02]
|
||||
// 03 sum (#02) grad[#03]
|
||||
// 04 const 1/N
|
||||
// 05 scale (#03, #04) grad[#05]
|
||||
// 06 const eps
|
||||
// 07 add (#05, #06) grad[#07]
|
||||
// 08 sqrt (#07) grad[#08]
|
||||
// 09 div (#01,#08) grad[#09]
|
||||
// 10 scale (#00,#09) grad[#10]
|
||||
//
|
||||
// backward pass, given grad[#10]
|
||||
// #10: scale
|
||||
// grad[#00] += scale(grad[#10],#09)
|
||||
// grad[#09] += sum(mul(grad[#10],#00))
|
||||
// #09: div
|
||||
// grad[#08] += neg(mul(grad[#09], div(#09,#08)))
|
||||
// #08: sqrt
|
||||
// grad[#07] += mul(grad[#08], div(0.5, #08))
|
||||
// #07: add
|
||||
// grad[#05] += grad[#07]
|
||||
// #05: scale
|
||||
// grad[#03] += scale(grad[#05],#04)
|
||||
// #03: sum
|
||||
// grad[#02] += repeat(grad[#03], #02)
|
||||
// #02:
|
||||
// grad[#00] += scale(mul(#00, grad[#02]), 2.0)
|
||||
//
|
||||
// substitute and simplify:
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
|
||||
// grad[#02] = repeat(grad[#03], #02)
|
||||
// grad[#02] = repeat(scale(grad[#05],#04), #02)
|
||||
// grad[#02] = repeat(scale(grad[#07],#04), #02)
|
||||
// grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
|
||||
// grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
|
||||
// grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
|
||||
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
|
||||
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
|
||||
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
|
||||
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
|
||||
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
|
||||
// grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
|
||||
// grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
|
||||
// a = b*c + d*e
|
||||
// a = b*c*f/f + d*e*f/f
|
||||
// a = (b*c*f + d*e*f)*(1/f)
|
||||
// a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
|
||||
// a = (b + d*e/c)*c
|
||||
// b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
|
||||
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
|
||||
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
|
||||
// a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
|
||||
// a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
|
||||
// a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
|
||||
// a = (dz + x*div(-mean_xdz,mean_eps))*rrms
|
||||
// grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
|
||||
// grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
|
||||
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
|
||||
}
|
||||
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
|
||||
// post-order:
|
||||
// dx := x
|
||||
// dx := scale(dx,-mean_xdz/mean_eps)
|
||||
// dx := add(dx, dz)
|
||||
// dx := scale(dx, rrms)
|
||||
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
// square(x)
|
||||
ggml_vec_mul_f32(ne00, dx, x, x);
|
||||
// -square(x)/(n*rms**3)
|
||||
ggml_vec_scale_f32(ne00, dx, scale);
|
||||
// 1/rms(x) - square(x)/(n*rms(x)**3)
|
||||
ggml_vec_acc1_f32(ne00, dx, rrms);
|
||||
// dy*(1/rms(x) - square(x)/(n*rms(x)**3))
|
||||
ggml_vec_mul_f32(ne00, dx, dx, dy);
|
||||
|
||||
ggml_vec_cpy_f32(ne00, dx, x);
|
||||
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
|
||||
ggml_vec_scale_f32(ne00, dx, -sum_xdz/sum_eps);
|
||||
ggml_vec_acc_f32(ne00, dx, dz);
|
||||
ggml_vec_scale_f32(ne00, dx, rrms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue