simplify broadcasting mul_mat backward using ggml_repeat_back

This commit is contained in:
xaedes 2023-09-09 18:55:18 +02:00
parent d3aaf0876a
commit d3f1b438a8
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

ggml.c — 37 lines changed (View file)

@@ -16842,48 +16842,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_out_prod(ctx, // [n,m,qq,rr]
                             src1,          // [n,p,qq,rr]
                             tensor->grad); // [m,p,qq,rr]
-                const int64_t n  = s1_tg->ne[0];
-                const int64_t m  = s1_tg->ne[1];
                 const int64_t qq = s1_tg->ne[2];
                 const int64_t rr = s1_tg->ne[3];
                 const int64_t q1 = src0->ne[2];
                 const int64_t r1 = src0->ne[3];
-                const int64_t nq = qq/q1;
-                const int64_t nr = rr/r1;
-                GGML_ASSERT(qq % q1 == 0);
-                GGML_ASSERT(rr % r1 == 0);
                 const bool ne2_broadcasted = qq > q1;
                 const bool ne3_broadcasted = rr > r1;
-                // handling broadcasted will create a lot of overhead.
-                // this could be greatly reduced if we had a ggml_sum_repetitions function.
-                // ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D]
-                // with a % A == 0, b % B == 0, etc.
-                // TODO: implement such function if necessary, it should be quite trivial
-                if (ne2_broadcasted) {
-                    printf("%s: ne2_broadcasted\n", __func__);
-                    // sum ne2 repetitions
-                    // s1_tg->ne == [n,m,qq=nq*q1,rr]
-                    s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
-                    s1_tg = ggml_transpose(ctx, s1_tg);                   // [nq,n*m,q1,rr]
-                    s1_tg = ggml_cont(ctx, s1_tg);                        // [nq,n*m,q1,rr]
-                    s1_tg = ggml_sum_rows(ctx, s1_tg);                    // [1,n*m,q1,rr]
-                    // due to following reshape we can omit this:
-                    // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr]
-                }
-                if (ne3_broadcasted) {
-                    printf("%s: ne3_broadcasted\n", __func__);
-                    // sum ne3 repetitions
-                    // s1_tg->ne == [n,m,q1,rr=nr*r1]
-                    s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1]
-                    s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3);         // [nr,n*m,q1,r1]
-                    s1_tg = ggml_cont(ctx, s1_tg);                        // [nr,n*m,q1,r1]
-                    s1_tg = ggml_sum_rows(ctx, s1_tg);                    // [1,n*m,q1,r1]
-                    // due to following reshape we can omit this:
-                    // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1]
-                }
                 if (ne2_broadcasted || ne3_broadcasted) {
-                    // make sure ne and n_dims match
-                    s1_tg = ggml_reshape(ctx, s1_tg, src0);
+                    // sum broadcast repetitions of s1_tg into shape of src0
+                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
                 }
                 src0->grad =
                     ggml_add_or_set(ctx,