implement ggml_cont backward pass

This commit is contained in:
xaedes 2023-04-28 18:12:25 +02:00
parent 02d3fd0894
commit 3d21f2646e
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

11
ggml.c
View file

@ -5822,7 +5822,6 @@ struct ggml_tensor * ggml_cont_impl(
bool is_node = false;
if (!inplace && a->grad) {
// TODO: implement backward
is_node = true;
}
@ -13188,7 +13187,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_CONT:
{
GGML_ASSERT(false); // TODO: not implemented
// same as cpy
if (src0->grad) {
GGML_ASSERT(ggml_is_contiguous(src0->grad));
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
}
if (src1->grad) {
// noop
}
} break;
case GGML_OP_RESHAPE:
{