simplify broadcasting mul_mat backward using ggml_repeat_back
This commit is contained in:
parent
d3aaf0876a
commit
d3f1b438a8
1 changed files with 2 additions and 35 deletions
37
ggml.c
37
ggml.c
|
@ -16842,48 +16842,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
ggml_out_prod(ctx, // [n,m,qq,rr]
|
ggml_out_prod(ctx, // [n,m,qq,rr]
|
||||||
src1, // [n,p,qq,rr]
|
src1, // [n,p,qq,rr]
|
||||||
tensor->grad); // [m,p,qq,rr]
|
tensor->grad); // [m,p,qq,rr]
|
||||||
const int64_t n = s1_tg->ne[0];
|
|
||||||
const int64_t m = s1_tg->ne[1];
|
|
||||||
const int64_t qq = s1_tg->ne[2];
|
const int64_t qq = s1_tg->ne[2];
|
||||||
const int64_t rr = s1_tg->ne[3];
|
const int64_t rr = s1_tg->ne[3];
|
||||||
const int64_t q1 = src0->ne[2];
|
const int64_t q1 = src0->ne[2];
|
||||||
const int64_t r1 = src0->ne[3];
|
const int64_t r1 = src0->ne[3];
|
||||||
const int64_t nq = qq/q1;
|
|
||||||
const int64_t nr = rr/r1;
|
|
||||||
GGML_ASSERT(qq % q1 == 0);
|
|
||||||
GGML_ASSERT(rr % r1 == 0);
|
|
||||||
const bool ne2_broadcasted = qq > q1;
|
const bool ne2_broadcasted = qq > q1;
|
||||||
const bool ne3_broadcasted = rr > r1;
|
const bool ne3_broadcasted = rr > r1;
|
||||||
// handling broadcasted will create a lot of overhead.
|
|
||||||
// this could be greatly reduced if we had a ggml_sum_repetitions function.
|
|
||||||
// ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D]
|
|
||||||
// with a % A == 0, b % B == 0, etc.
|
|
||||||
// TODO: implement such function if necessary, it should be quite trivial
|
|
||||||
if (ne2_broadcasted) {
|
|
||||||
printf("%s: ne2_broadcasted\n", __func__);
|
|
||||||
// sum ne2 repetitions
|
|
||||||
// s1_tg->ne == [n,m,qq=nq*q1,rr]
|
|
||||||
s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
|
|
||||||
s1_tg = ggml_transpose(ctx, s1_tg); // [nq,n*m,q1,rr]
|
|
||||||
s1_tg = ggml_cont(ctx, s1_tg); // [nq,n*m,q1,rr]
|
|
||||||
s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,rr]
|
|
||||||
// due to following reshape we can omit this:
|
|
||||||
// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr]
|
|
||||||
}
|
|
||||||
if (ne3_broadcasted) {
|
|
||||||
printf("%s: ne3_broadcasted\n", __func__);
|
|
||||||
// sum ne3 repetitions
|
|
||||||
// s1_tg->ne == [n,m,q1,rr=nr*r1]
|
|
||||||
s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1]
|
|
||||||
s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3); // [nr,n*m,q1,r1]
|
|
||||||
s1_tg = ggml_cont(ctx, s1_tg); // [nr,n*m,q1,r1]
|
|
||||||
s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,r1]
|
|
||||||
// due to following reshape we can omit this:
|
|
||||||
// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1]
|
|
||||||
}
|
|
||||||
if (ne2_broadcasted || ne3_broadcasted) {
|
if (ne2_broadcasted || ne3_broadcasted) {
|
||||||
// make sure ne and n_dims match
|
// sum broadcast repetitions of s1_tg into shape of src0
|
||||||
s1_tg = ggml_reshape(ctx, s1_tg, src0);
|
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
|
||||||
}
|
}
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_or_set(ctx,
|
ggml_add_or_set(ctx,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue