bug fix for cpy backward pass

This commit is contained in:
xaedes 2023-04-26 20:14:33 +02:00
parent 7571147242
commit 0ea8201c86
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

6
ggml.c
View file

@ -12809,11 +12809,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_CPY:
{
// necessary for llama
// cpy overwrites value of src1 by src0 and returns view(src1)
// the overwriting is mathematically equivalent to:
// tensor = src0 * 1 + src1 * 0
if (src0->grad) {
// dsrc0 = dtensor * 1
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
}
if (src1->grad) {
src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
// dsrc1 = dtensor * 0 -> noop
}
} break;
case GGML_OP_CONT: