implement ggml_cont backward pass
This commit is contained in:
parent
02d3fd0894
commit
3d21f2646e
1 changed files with 9 additions and 2 deletions
11
ggml.c
11
ggml.c
|
@ -5822,7 +5822,6 @@ struct ggml_tensor * ggml_cont_impl(
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (!inplace && a->grad) {
|
if (!inplace && a->grad) {
|
||||||
// TODO: implement backward
|
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13188,7 +13187,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
// same as cpy
|
||||||
|
if (src0->grad) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0->grad));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
|
||||||
|
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||||
|
}
|
||||||
|
if (src1->grad) {
|
||||||
|
// noop
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue