diff --git a/ggml.c b/ggml.c index ef9d3d692..40ee3eeb1 100644 --- a/ggml.c +++ b/ggml.c @@ -9831,15 +9831,15 @@ static void ggml_compute_forward_rms_norm_back_f32( sum_xdz += (ggml_float)(x[i00] * dz[i00]); } - const float mean = sum_xx/ne00; - const float mean_eps = sum_xx/ne00 + eps; - const float sum_eps = sum_xx + eps*ne00; - const float mean_xdz = sum_xdz/ne00; + const ggml_float mean = sum_xx/ne00; + const ggml_float mean_eps = sum_xx/ne00 + eps; + const ggml_float sum_eps = sum_xx + eps*ne00; + const ggml_float mean_xdz = sum_xdz/ne00; // we could cache rms from forward pass to improve performance. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. - const float rms = sqrtf(mean_eps); - const float rrms = 1.0f / sqrtf(mean_eps); - const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + const ggml_float rms = sqrtf(mean_eps); + const ggml_float rrms = 1.0f / sqrtf(mean_eps); + const ggml_float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) { // z = rms_norm(x)