bug fix for cpy backward pass
This commit is contained in:
parent
7571147242
commit
0ea8201c86
1 changed files with 5 additions and 1 deletions
6
ggml.c
6
ggml.c
|
@ -12809,11 +12809,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
{
|
{
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
|
// cpy overwrites value of src1 by src0 and returns view(src1)
|
||||||
|
// the overwriting is mathematically equivalent to:
|
||||||
|
// tensor = src0 * 1 + src1 * 0
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
|
// dsrc0 = dtensor * 1
|
||||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
|
// dsrc1 = dtensor * 0 -> noop
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue