diff --git a/ggml.c b/ggml.c index 9b65ec822..85e5e941c 100644 --- a/ggml.c +++ b/ggml.c @@ -5709,7 +5709,7 @@ struct ggml_tensor * ggml_rms_norm_back( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + // TODO: implement backward is_node = true; } @@ -9224,34 +9224,126 @@ static void ggml_compute_forward_rms_norm_back_f32( const auto i12 = i02; const auto i13 = i03; const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dy = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); - ggml_float sum = 0.0; + ggml_float sum_xx = 0.0; + ggml_float sum_xdz = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); + sum_xx += (ggml_float)(x[i00] * x[i00]); + sum_xdz += (ggml_float)(x[i00] * dz[i00]); } - const float mean = sum/ne00; - const float mean_eps = sum/ne00 + eps; + const float mean = sum_xx/ne00; + const float mean_eps = sum_xx/ne00 + eps; + const float sum_eps = sum_xx + eps*ne00; + const float mean_xdz = sum_xdz/ne00; // we could cache rms from forward pass to improve performance. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. const float rms = sqrtf(mean_eps); const float rrms = 1.0f / sqrtf(mean_eps); const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) - // rms(x) = sqrt(eps + mean(square(x))) ; scalar - // y = rms_norm(x) = x/rms(x) = x/sqrt(eps+mean(square(x))) ; vector - // dx = dy*(1/rms(x) - square(x)/(n*rms(x)**3)) + { + // z = rms_norm(x) + // + // rms_norm(src0) = + // scale( + // src0, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src0)), + // (1.0/N)), + // eps)))); + // postorder: + // ## op args grad + // 00 param src0 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - // square(x) - ggml_vec_mul_f32(ne00, dx, x, x); - // -square(x)/(n*rms**3) - ggml_vec_scale_f32(ne00, dx, scale); - // 1/rms(x) - square(x)/(n*rms(x)**3) - ggml_vec_acc1_f32(ne00, dx, rrms); - // dy*(1/rms(x) - square(x)/(n*rms(x)**3)) - ggml_vec_mul_f32(ne00, dx, dx, dy); + + ggml_vec_cpy_f32(ne00, dx, x); + // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + ggml_vec_scale_f32(ne00, dx, -sum_xdz/sum_eps); + ggml_vec_acc_f32(ne00, dx, dz); + ggml_vec_scale_f32(ne00, dx, rrms); } } }