bug fix for scale backward pass

use sum instead of mean for gradient of scalar scale parameter
This commit is contained in:
xaedes 2023-04-25 22:25:53 +02:00
parent 671e5922e2
commit a367eb9eda
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

2
ggml.c
View file

@ -12802,7 +12802,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad =
ggml_add_impl(ctx,
src1->grad,
ggml_mean(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
inplace);
}
} break;