From 0ea8201c8677a9eb93173d72f67cd8cdbb002fac Mon Sep 17 00:00:00 2001
From: xaedes
Date: Wed, 26 Apr 2023 20:14:33 +0200
Subject: [PATCH] bug fix for cpy backward pass

---
 ggml.c | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index 723e55d99..cde938eae 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12809,11 +12809,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_CPY:
             {
                 // necessary for llama
+                // cpy overwrites value of src1 by src0 and returns view(src1)
+                // the overwriting is mathematically equivalent to:
+                // tensor = src0 * 1 + src1 * 0
                 if (src0->grad) {
+                    // dsrc0 = dtensor * 1
                     src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+                    // dsrc1 = dtensor * 0 -> noop
                 }
             } break;
         case GGML_OP_CONT: