support broadcastable a in out_prod(a, b) and backward pass of broadcasting mul_mat(a, b)

xaedes 2023-09-09 18:37:45 +02:00
parent 35260f7d74
commit aea8b6be74

ggml.c

@@ -4363,10 +4363,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-return
-(t0->ne[1] == t1->ne[1]) &&
-(t0->ne[2] == t1->ne[2]) &&
-(t0->ne[3] == t1->ne[3]);
+return (t0->ne[1] == t1->ne[1]) &&
+(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+(t1->ne[3]%t0->ne[3] == 0);
}
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
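The relaxed check only requires that t1->ne[2] and t1->ne[3] are integer multiples of t0's, instead of an exact match. A minimal standalone sketch of the new condition on plain shape arrays (the helper name and example shapes are illustrative, not part of ggml):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// mirrors the relaxed ggml_can_out_prod condition: ne[1] must match exactly,
// while t1->ne[2] and t1->ne[3] only need to be multiples of t0's (t0 broadcastable to t1)
static bool can_out_prod_shapes(const int64_t t0[4], const int64_t t1[4]) {
    return (t0[1] == t1[1]) &&
           (t1[2] % t0[2] == 0) &&
           (t1[3] % t0[3] == 0);
}

int main(void) {
    const int64_t a[4] = { 64, 32, 1, 1 };  // singleton ne[2]/ne[3], broadcastable to b
    const int64_t b[4] = { 64, 32, 8, 4 };
    const int64_t c[4] = { 64, 32, 3, 1 };  // 8 % 3 != 0, not broadcastable to b
    printf("%d %d\n", can_out_prod_shapes(a, b), can_out_prod_shapes(c, b)); // prints: 1 0
    return 0;
}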
@@ -6358,7 +6357,8 @@ struct ggml_tensor * ggml_out_prod(
is_node = true;
}
-const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+// a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
result->op = GGML_OP_OUT_PROD;
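Since a is broadcast to b in dims 2 and 3, the result tensor takes its ne[2] and ne[3] from b. A naive standalone reference of what such a broadcastable out_prod computes might look like the sketch below (flat float arrays with ne[0] fastest-varying as in ggml; the consecutive-block grouping of repetitions is an assumption chosen to match the backward reduction further down, and this is not the actual ggml_compute_forward_out_prod kernel):

#include <stdint.h>
#include <stdio.h>

// a: [n, k, q1, r1], b: [p, k, q2, r2] with q2 % q1 == 0 and r2 % r1 == 0, dst: [n, p, q2, r2]
// dst[i3][i2][j][i] = sum_k a[i3/(r2/r1)][i2/(q2/q1)][k][i] * b[i3][i2][k][j]
static void out_prod_ref(float * dst, const float * a, const float * b,
                         int64_t n, int64_t p, int64_t k,
                         int64_t q1, int64_t r1, int64_t q2, int64_t r2) {
    const int64_t nq = q2/q1;
    const int64_t nr = r2/r1;
    for (int64_t i3 = 0; i3 < r2; ++i3) {
        for (int64_t i2 = 0; i2 < q2; ++i2) {
            const float * a23 = a   + ((i3/nr)*q1 + (i2/nq))*k*n; // a is broadcast in dims 2/3
            const float * b23 = b   + (i3*q2 + i2)*k*p;
            float       * d23 = dst + (i3*q2 + i2)*p*n;
            for (int64_t j = 0; j < p; ++j) {
                for (int64_t i = 0; i < n; ++i) {
                    float sum = 0.0f;
                    for (int64_t kk = 0; kk < k; ++kk) {
                        sum += a23[kk*n + i] * b23[kk*p + j];
                    }
                    d23[j*n + i] = sum;
                }
            }
        }
    }
}

int main(void) {
    // a: [2,3,1,1] is broadcast against b: [2,3,2,1] -> dst: [2,2,2,1]
    const float a[2*3]   = { 1, 2,  3, 4,  5, 6 };
    const float b[2*3*2] = { 1, 0,  0, 1,  1, 1,
                             2, 0,  0, 2,  2, 2 };
    float dst[2*2*2];
    out_prod_ref(dst, a, b, 2, 2, 3, 1, 1, 2, 1);
    for (int i = 0; i < 2*2*2; ++i) {
        printf("%g ", dst[i]);
    }
    printf("\n");
    return 0;
}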
@@ -16832,36 +16832,81 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
// ds1 = t.T.dot(dt)
-// tensor.shape [m,p]
-// src0.shape [n,m]
-// src1.shape [n,p]
+// tensor.shape [m,p,qq,rr]
+// src0.shape [n,m,q1,r1]
+// src1.shape [n,p,qq,rr]
// necessary for llama
if (src0->grad) {
+struct ggml_tensor * s1_tg =
+ggml_out_prod(ctx, // [n,m,qq,rr]
+src1, // [n,p,qq,rr]
+tensor->grad); // [m,p,qq,rr]
+const int64_t n = s1_tg->ne[0];
+const int64_t m = s1_tg->ne[1];
+const int64_t qq = s1_tg->ne[2];
+const int64_t rr = s1_tg->ne[3];
+const int64_t q1 = src0->ne[2];
+const int64_t r1 = src0->ne[3];
+const int64_t nq = qq/q1;
+const int64_t nr = rr/r1;
+GGML_ASSERT(qq % q1 == 0);
+GGML_ASSERT(rr % r1 == 0);
+const bool ne2_broadcasted = qq > q1;
+const bool ne3_broadcasted = rr > r1;
+// handling the broadcasted case will create a lot of overhead.
+// this could be greatly reduced if we had a ggml_sum_repetitions function.
+// ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D]
+// with a % A == 0, b % B == 0, etc.
+// TODO: implement such a function if necessary; it should be quite trivial
+if (ne2_broadcasted) {
+printf("%s: ne2_broadcasted\n", __func__);
+// sum ne2 repetitions
+// s1_tg->ne == [n,m,qq=nq*q1,rr]
+s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
+s1_tg = ggml_transpose(ctx, s1_tg); // [nq,n*m,q1,rr]
+s1_tg = ggml_cont(ctx, s1_tg); // [nq,n*m,q1,rr]
+s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,rr]
+// due to following reshape we can omit this:
+// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr]
+}
+if (ne3_broadcasted) {
+printf("%s: ne3_broadcasted\n", __func__);
+// sum ne3 repetitions
+// s1_tg->ne == [n,m,q1,rr=nr*r1]
+s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1]
+s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3); // [nr,n*m,q1,r1]
+s1_tg = ggml_cont(ctx, s1_tg); // [nr,n*m,q1,r1]
+s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,r1]
+// due to following reshape we can omit this:
+// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1]
+}
+if (ne2_broadcasted || ne3_broadcasted) {
+// make sure ne and n_dims match
+s1_tg = ggml_reshape(ctx, s1_tg, src0);
+}
src0->grad =
ggml_add_or_set(ctx,
-src0->grad,
-ggml_out_prod(ctx, // [n,m]
-src1, // [n,p]
-tensor->grad), // [m,p]
+src0->grad, // [n,m,q1,r1]
+s1_tg, // [n,m,q1,r1]
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_add_or_set(ctx,
-src1->grad,
-// ggml_mul_mat(ctx, // [n,p]
-// ggml_cont(ctx, // [m,n]
-// ggml_transpose(ctx, src0)), // [m,n]
-// tensor->grad), // [m,p]
+src1->grad, // [n,p,qq,rr]
+// ggml_mul_mat(ctx, // [n,p,qq,rr]
+// ggml_cont(ctx, // [m,n,q1,r1]
+// ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+// tensor->grad), // [m,p,qq,rr]
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
// // avoid transpose of src0, rather transpose smaller tensor->grad
// // and then use ggml_out_prod
-ggml_out_prod(ctx, // [n,p]
-src0, // [n,m]
-ggml_transpose(ctx, // [p,m]
-tensor->grad)), // [m,p]
+ggml_out_prod(ctx, // [n,p,qq,rr]
+src0, // [n,m,q1,r1]
+ggml_transpose(ctx, // [p,m,qq,rr]
+tensor->grad)), // [m,p,qq,rr]
zero_table);
}
} break;
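The comments added above describe a hypothetical ggml_sum_repetitions that would collapse the broadcast repetitions of s1_tg back into src0's shape in one pass, replacing the reshape/transpose/cont/sum_rows chains. A rough standalone sketch of such a reduction on flat float arrays (illustrative only, not a ggml op; it sums consecutive blocks of repetitions, the same grouping those chains produce):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// u has shape une = [a, b, c, d] (ne[0] fastest-varying), v has shape vne = [A, B, C, D]
// with a % A == 0, b % B == 0, etc.; along each dim, consecutive blocks of repetitions
// in u are accumulated into the corresponding element of v
static void sum_repetitions(float * v, const int64_t vne[4],
                            const float * u, const int64_t une[4]) {
    const int64_t s0 = une[0]/vne[0];
    const int64_t s1 = une[1]/vne[1];
    const int64_t s2 = une[2]/vne[2];
    const int64_t s3 = une[3]/vne[3];
    memset(v, 0, sizeof(float)*vne[0]*vne[1]*vne[2]*vne[3]);
    for (int64_t i3 = 0; i3 < une[3]; ++i3) {
        for (int64_t i2 = 0; i2 < une[2]; ++i2) {
            for (int64_t i1 = 0; i1 < une[1]; ++i1) {
                for (int64_t i0 = 0; i0 < une[0]; ++i0) {
                    const int64_t iu = ((i3*une[2] + i2)*une[1] + i1)*une[0] + i0;
                    const int64_t iv = (((i3/s3)*vne[2] + i2/s2)*vne[1] + i1/s1)*vne[0] + i0/s0;
                    v[iv] += u[iu];
                }
            }
        }
    }
}

int main(void) {
    // sum nq=2 repetitions along dim 2: [2,1,4,1] -> [2,1,2,1]
    const int64_t une[4] = { 2, 1, 4, 1 };
    const int64_t vne[4] = { 2, 1, 2, 1 };
    const float u[8] = { 1, 2,  3, 4,  5, 6,  7, 8 };
    float v[4];
    sum_repetitions(v, vne, u, une);
    for (int i = 0; i < 4; ++i) {
        printf("%g ", v[i]); // 1+3, 2+4, 5+7, 6+8
    }
    printf("\n");
    return 0;
}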