bug fix for reshape backward pass
This commit is contained in:
parent
b2bd8222da
commit
c483a7dac5
1 changed files with 2 additions and 2 deletions
4
ggml.c
4
ggml.c
|
@ -5863,7 +5863,7 @@ struct ggml_tensor * ggml_reshape(
|
||||||
|
|
||||||
if (b->grad) {
|
if (b->grad) {
|
||||||
// gradient propagation is not supported
|
// gradient propagation is not supported
|
||||||
GGML_ASSERT(false);
|
//GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
|
||||||
|
@ -12830,7 +12830,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_impl(ctx, src0->grad,
|
ggml_add_impl(ctx, src0->grad,
|
||||||
ggml_reshape(ctx, tensor->grad, src1),
|
ggml_reshape(ctx, tensor->grad, src0->grad),
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue