bug fix for reshape backward pass

This commit is contained in:
xaedes 2023-04-26 20:34:08 +02:00
parent b2bd8222da
commit c483a7dac5
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

4
ggml.c
View file

@ -5863,7 +5863,7 @@ struct ggml_tensor * ggml_reshape(
if (b->grad) {
// gradient propagation is not supported
GGML_ASSERT(false);
//GGML_ASSERT(false);
}
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
@ -12830,7 +12830,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
src0->grad =
ggml_add_impl(ctx, src0->grad,
ggml_reshape(ctx, tensor->grad, src1),
ggml_reshape(ctx, tensor->grad, src0->grad),
inplace);
}
if (src1->grad) {