diff --git a/ggml.c b/ggml.c index d718de33b..07d100bf0 100644 --- a/ggml.c +++ b/ggml.c @@ -14334,7 +14334,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( if (ith == 0) { float * dp = (float *) dst->data; ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f; + dp[0] *= -1.0f / (float) nr; } return; } @@ -14506,7 +14506,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( ggml_vec_scale_f32(nc, ds0, sum); ggml_vec_add1_f32(nc, ds0, ds0, eps); ggml_vec_sub_f32(nc, ds0, ds0, s1); - ggml_vec_scale_f32(nc, ds0, d[0]); + ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); #ifndef NDEBUG