From d3f1b438a84bda322ee9d2c674b7c10393f1d940 Mon Sep 17 00:00:00 2001 From: xaedes Date: Sat, 9 Sep 2023 18:55:18 +0200 Subject: [PATCH] simplify broadcasting mul_mat backward using ggml_repeat_back --- ggml.c | 37 ++----------------------------------- 1 file changed, 2 insertions(+), 35 deletions(-) diff --git a/ggml.c b/ggml.c index 2a8d95ec8..f6dca255f 100644 --- a/ggml.c +++ b/ggml.c @@ -16842,48 +16842,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_out_prod(ctx, // [n,m,qq,rr] src1, // [n,p,qq,rr] tensor->grad); // [m,p,qq,rr] - const int64_t n = s1_tg->ne[0]; - const int64_t m = s1_tg->ne[1]; const int64_t qq = s1_tg->ne[2]; const int64_t rr = s1_tg->ne[3]; const int64_t q1 = src0->ne[2]; const int64_t r1 = src0->ne[3]; - const int64_t nq = qq/q1; - const int64_t nr = rr/r1; - GGML_ASSERT(qq % q1 == 0); - GGML_ASSERT(rr % r1 == 0); const bool ne2_broadcasted = qq > q1; const bool ne3_broadcasted = rr > r1; - // handling broadcasted will create a lot of overhead. - // this could be greatly reduced if we had a ggml_sum_repetitions function. - // ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D] - // with a % A == 0, b % B == 0, etc. - // TODO: implement such function if necessary, it should be quite trivial - if (ne2_broadcasted) { - printf("%s: ne2_broadcasted\n", __func__); - // sum ne2 repetitions - // s1_tg->ne == [n,m,qq=nq*q1,rr] - s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr] - s1_tg = ggml_transpose(ctx, s1_tg); // [nq,n*m,q1,rr] - s1_tg = ggml_cont(ctx, s1_tg); // [nq,n*m,q1,rr] - s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,rr] - // due to following reshape we can omit this: - // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr] - } - if (ne3_broadcasted) { - printf("%s: ne3_broadcasted\n", __func__); - // sum ne3 repetitions - // s1_tg->ne == [n,m,q1,rr=nr*r1] - s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1] - s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3); // [nr,n*m,q1,r1] - s1_tg = ggml_cont(ctx, s1_tg); // [nr,n*m,q1,r1] - s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,r1] - // due to following reshape we can omit this: - // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1] - } if (ne2_broadcasted || ne3_broadcasted) { - // make sure ne and n_dims match - s1_tg = ggml_reshape(ctx, s1_tg, src0); + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); } src0->grad = ggml_add_or_set(ctx,