implement backward pass for ggml_sum_rows, necessary for cross entropy loss

This commit is contained in:
xaedes 2023-05-06 17:37:51 +02:00
parent 5724628d31
commit 7a15a8370c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

10
ggml.c
View file

@ -13280,7 +13280,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_SUM_ROWS:
{
GGML_ASSERT(false); // TODO: implement
if (src0->grad) {
src0->grad =
ggml_add_impl(ctx,
src0->grad,
ggml_repeat(ctx,
tensor->grad,
src0->grad),
inplace);
}
} break;
case GGML_OP_MEAN:
{