From 7a15a8370c44fceab5d9a33435f0df0bc5b910fe Mon Sep 17 00:00:00 2001 From: xaedes Date: Sat, 6 May 2023 17:37:51 +0200 Subject: [PATCH] implement backward pass for ggml_sum_rows, necessary for cross entropy loss --- ggml.c | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index c802be1f2..95c273be3 100644 --- a/ggml.c +++ b/ggml.c @@ -13280,7 +13280,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_SUM_ROWS: { - GGML_ASSERT(false); // TODO: implement + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, + tensor->grad, + src0->grad), + inplace); + } } break; case GGML_OP_MEAN: {