slightly improve how cross entropy loss is computed

btw: the directly implemented cross entropy loss seems to have a much lower magnitude than when implemented with separate softmax and log steps.
probably the input to log gets closer to zero due to float numerics.
maybe the single multiplication by (1.0-eps)/sum is more accurate.
xaedes 2023-05-28 22:40:58 +02:00
parent 5f5aa20078
commit 89475fb320

ggml.c

@@ -12961,10 +12961,10 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             }
             assert(sum > 0.0);
-            sum = 1.0/sum;
+            // sum = 1.0/sum;
         }
         // avoid log(0) by rescaling from [0..1] to [eps..1]
-        sum = sum * (1.0f - eps);
+        sum = (1.0f - eps) / sum;
         ggml_vec_scale_f32(nc, st, sum);
         ggml_vec_add1_f32(nc, st, st, eps);
         ggml_vec_log_f32(nc, st, st);