support broadcastable a in out_prod(a, b) and backward pass of broadcasting mul_mat(a, b)
This commit is contained in:
parent
35260f7d74
commit
aea8b6be74
1 changed files with 66 additions and 21 deletions
87
ggml.c
87
ggml.c
|
@ -4363,10 +4363,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
||||||
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||||
|
|
||||||
return
|
return (t0->ne[1] == t1->ne[1]) &&
|
||||||
(t0->ne[1] == t1->ne[1]) &&
|
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
|
||||||
(t0->ne[2] == t1->ne[2]) &&
|
(t1->ne[3]%t0->ne[3] == 0);
|
||||||
(t0->ne[3] == t1->ne[3]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||||
|
@ -6358,7 +6357,8 @@ struct ggml_tensor * ggml_out_prod(
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
|
// a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
|
||||||
|
const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
|
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
|
||||||
|
|
||||||
result->op = GGML_OP_OUT_PROD;
|
result->op = GGML_OP_OUT_PROD;
|
||||||
|
@ -16832,36 +16832,81 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
||||||
// ds1 = t.T.dot(dt)
|
// ds1 = t.T.dot(dt)
|
||||||
|
|
||||||
// tensor.shape [m,p]
|
// tensor.shape [m,p,qq,rr]
|
||||||
// src0.shape [n,m]
|
// src0.shape [n,m,q1,r1]
|
||||||
// src1.shape [n,p]
|
// src1.shape [n,p,qq,rr]
|
||||||
|
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
|
struct ggml_tensor * s1_tg =
|
||||||
|
ggml_out_prod(ctx, // [n,m,qq,rr]
|
||||||
|
src1, // [n,p,qq,rr]
|
||||||
|
tensor->grad); // [m,p,qq,rr]
|
||||||
|
const int64_t n = s1_tg->ne[0];
|
||||||
|
const int64_t m = s1_tg->ne[1];
|
||||||
|
const int64_t qq = s1_tg->ne[2];
|
||||||
|
const int64_t rr = s1_tg->ne[3];
|
||||||
|
const int64_t q1 = src0->ne[2];
|
||||||
|
const int64_t r1 = src0->ne[3];
|
||||||
|
const int64_t nq = qq/q1;
|
||||||
|
const int64_t nr = rr/r1;
|
||||||
|
GGML_ASSERT(qq % q1 == 0);
|
||||||
|
GGML_ASSERT(rr % r1 == 0);
|
||||||
|
const bool ne2_broadcasted = qq > q1;
|
||||||
|
const bool ne3_broadcasted = rr > r1;
|
||||||
|
// handling broadcasted will create a lot of overhead.
|
||||||
|
// this could be greatly reduced if we had a ggml_sum_repetitions function.
|
||||||
|
// ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D]
|
||||||
|
// with a % A == 0, b % B == 0, etc.
|
||||||
|
// TODO: implement such function if necessary, it should be quite trivial
|
||||||
|
if (ne2_broadcasted) {
|
||||||
|
printf("%s: ne2_broadcasted\n", __func__);
|
||||||
|
// sum ne2 repetitions
|
||||||
|
// s1_tg->ne == [n,m,qq=nq*q1,rr]
|
||||||
|
s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
|
||||||
|
s1_tg = ggml_transpose(ctx, s1_tg); // [nq,n*m,q1,rr]
|
||||||
|
s1_tg = ggml_cont(ctx, s1_tg); // [nq,n*m,q1,rr]
|
||||||
|
s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,rr]
|
||||||
|
// due to following reshape we can omit this:
|
||||||
|
// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr]
|
||||||
|
}
|
||||||
|
if (ne3_broadcasted) {
|
||||||
|
printf("%s: ne3_broadcasted\n", __func__);
|
||||||
|
// sum ne3 repetitions
|
||||||
|
// s1_tg->ne == [n,m,q1,rr=nr*r1]
|
||||||
|
s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1]
|
||||||
|
s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3); // [nr,n*m,q1,r1]
|
||||||
|
s1_tg = ggml_cont(ctx, s1_tg); // [nr,n*m,q1,r1]
|
||||||
|
s1_tg = ggml_sum_rows(ctx, s1_tg); // [1,n*m,q1,r1]
|
||||||
|
// due to following reshape we can omit this:
|
||||||
|
// s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1]
|
||||||
|
}
|
||||||
|
if (ne2_broadcasted || ne3_broadcasted) {
|
||||||
|
// make sure ne and n_dims match
|
||||||
|
s1_tg = ggml_reshape(ctx, s1_tg, src0);
|
||||||
|
}
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_or_set(ctx,
|
ggml_add_or_set(ctx,
|
||||||
src0->grad,
|
src0->grad, // [n,m,q1,r1]
|
||||||
ggml_out_prod(ctx, // [n,m]
|
s1_tg, // [n,m,q1,r1]
|
||||||
src1, // [n,p]
|
|
||||||
tensor->grad), // [m,p]
|
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
src1->grad =
|
src1->grad =
|
||||||
ggml_add_or_set(ctx,
|
ggml_add_or_set(ctx,
|
||||||
src1->grad,
|
src1->grad, // [n,p,qq,rr]
|
||||||
// ggml_mul_mat(ctx, // [n,p]
|
// ggml_mul_mat(ctx, // [n,p,qq,rr]
|
||||||
// ggml_cont(ctx, // [m,n]
|
// ggml_cont(ctx, // [m,n,q1,r1]
|
||||||
// ggml_transpose(ctx, src0)), // [m,n]
|
// ggml_transpose(ctx, src0)), // [m,n,q1,r1]
|
||||||
// tensor->grad), // [m,p]
|
// tensor->grad), // [m,p,qq,rr]
|
||||||
|
|
||||||
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
|
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
|
||||||
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
||||||
// // and then use ggml_out_prod
|
// // and then use ggml_out_prod
|
||||||
ggml_out_prod(ctx, // [n,p]
|
ggml_out_prod(ctx, // [n,p,qq,rr]
|
||||||
src0, // [n,m]
|
src0, // [n,m,q1,r1]
|
||||||
ggml_transpose(ctx, // [p,m]
|
ggml_transpose(ctx, // [p,m,qq,rr]
|
||||||
tensor->grad)), // [m,p]
|
tensor->grad)), // [m,p,qq,rr]
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue