bug fix for reshape backward pass
This commit is contained in:
parent
b2bd8222da
commit
c483a7dac5
1 changed files with 2 additions and 2 deletions
4
ggml.c
4
ggml.c
|
@ -5863,7 +5863,7 @@ struct ggml_tensor * ggml_reshape(
|
|||
|
||||
if (b->grad) {
|
||||
// gradient propagation is not supported
|
||||
GGML_ASSERT(false);
|
||||
//GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
|
||||
|
@ -12830,7 +12830,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_reshape(ctx, tensor->grad, src1),
|
||||
ggml_reshape(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
}
|
||||
if (src1->grad) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue