implement backward pass for ggml_sum_rows, necessary for cross entropy loss
This commit is contained in:
parent
5724628d31
commit
7a15a8370c
1 changed files with 9 additions and 1 deletions
10
ggml.c
10
ggml.c
|
@ -13280,7 +13280,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: implement
|
if (src0->grad) {
|
||||||
|
src0->grad =
|
||||||
|
ggml_add_impl(ctx,
|
||||||
|
src0->grad,
|
||||||
|
ggml_repeat(ctx,
|
||||||
|
tensor->grad,
|
||||||
|
src0->grad),
|
||||||
|
inplace);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue