diff --git a/ggml.c b/ggml.c
index 5b1c3c79c..2a8d95ec8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4363,10 +4363,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[1] == t1->ne[1]) &&
-        (t0->ne[2] == t1->ne[2]) &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[1] == t1->ne[1]) &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -6358,7 +6357,8 @@ struct ggml_tensor * ggml_out_prod(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_OUT_PROD;
@@ -16832,36 +16832,81 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
                 // ds1 = t.T.dot(dt)
 
-                // tensor.shape [m,p]
-                // src0.shape   [n,m]
-                // src1.shape   [n,p]
+                // tensor.shape [m,p,qq,rr]
+                // src0.shape   [n,m,q1,r1]
+                // src1.shape   [n,p,qq,rr]
 
                 // necessary for llama
                 if (src0->grad) {
+                    struct ggml_tensor * s1_tg =
+                        ggml_out_prod(ctx, // [n,m,qq,rr]
+                            src1,          // [n,p,qq,rr]
+                            tensor->grad); // [m,p,qq,rr]
+                    const int64_t n  = s1_tg->ne[0];
+                    const int64_t m  = s1_tg->ne[1];
+                    const int64_t qq = s1_tg->ne[2];
+                    const int64_t rr = s1_tg->ne[3];
+                    const int64_t q1 = src0->ne[2];
+                    const int64_t r1 = src0->ne[3];
+                    const int64_t nq = qq/q1;
+                    const int64_t nr = rr/r1;
+                    GGML_ASSERT(qq % q1 == 0);
+                    GGML_ASSERT(rr % r1 == 0);
+                    const bool ne2_broadcasted = qq > q1;
+                    const bool ne3_broadcasted = rr > r1;
+                    // handling broadcasted will create a lot of overhead.
+                    // this could be greatly reduced if we had a ggml_sum_repetitions function.
+                    // ggml_sum_repetitions(ctx, u with ne=[a,b,c,d], v with ne=[A,B,C,D]) -> ne=[A,B,C,D]
+                    // with a % A == 0, b % B == 0, etc.
+                    // TODO: implement such function if necessary, it should be quite trivial
+                    if (ne2_broadcasted) {
+                        printf("%s: ne2_broadcasted\n", __func__);
+                        // sum ne2 repetitions
+                        // s1_tg->ne == [n,m,qq=nq*q1,rr]
+                        s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
+                        s1_tg = ggml_transpose(ctx, s1_tg);                   // [nq,n*m,q1,rr]
+                        s1_tg = ggml_cont(ctx, s1_tg);                        // [nq,n*m,q1,rr]
+                        s1_tg = ggml_sum_rows(ctx, s1_tg);                    // [1,n*m,q1,rr]
+                        // due to following reshape we can omit this:
+                        // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, rr); // [n,m,q1,rr]
+                    }
+                    if (ne3_broadcasted) {
+                        printf("%s: ne3_broadcasted\n", __func__);
+                        // sum ne3 repetitions
+                        // s1_tg->ne == [n,m,q1,rr=nr*r1]
+                        s1_tg = ggml_reshape_4d(ctx, s1_tg, n*m, q1, nr, r1); // [n*m,q1,nr,r1]
+                        s1_tg = ggml_permute(ctx, s1_tg, 1, 2, 0, 3);         // [nr,n*m,q1,r1]
+                        s1_tg = ggml_cont(ctx, s1_tg);                        // [nr,n*m,q1,r1]
+                        s1_tg = ggml_sum_rows(ctx, s1_tg);                    // [1,n*m,q1,r1]
+                        // due to following reshape we can omit this:
+                        // s1_tg = ggml_reshape_4d(ctx, s1_tg, n, m, q1, r1); // [n,m,q1,r1]
+                    }
+                    if (ne2_broadcasted || ne3_broadcasted) {
+                        // make sure ne and n_dims match
+                        s1_tg = ggml_reshape(ctx, s1_tg, src0);
+                    }
                     src0->grad =
                         ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_out_prod(ctx, // [n,m]
-                                    src1,          // [n,p]
-                                    tensor->grad), // [m,p]
+                                src0->grad, // [n,m,q1,r1]
+                                s1_tg,      // [n,m,q1,r1]
                                 zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_or_set(ctx,
-                                src1->grad,
-                                // ggml_mul_mat(ctx,                   // [n,p]
-                                //     ggml_cont(ctx,                  // [m,n]
-                                //         ggml_transpose(ctx, src0)), // [m,n]
-                                //     tensor->grad),                  // [m,p]
+                                src1->grad,                            // [n,p,qq,rr]
+                                // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
+                                //     ggml_cont(ctx,                  // [m,n,q1,r1]
+                                //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+                                //     tensor->grad),                  // [m,p,qq,rr]
 
                                 // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
                                 // // avoid transpose of src0, rather transpose smaller tensor->grad
                                 // // and then use ggml_out_prod
-                                ggml_out_prod(ctx,                  // [n,p]
-                                        src0,                       // [n,m]
-                                        ggml_transpose(ctx,         // [p,m]
-                                            tensor->grad)),         // [m,p]
+                                ggml_out_prod(ctx,                  // [n,p,qq,rr]
+                                        src0,                       // [n,m,q1,r1]
+                                        ggml_transpose(ctx,         // [p,m,qq,rr]
+                                            tensor->grad)),         // [m,p,qq,rr]
                                 zero_table);
                 }
             } break;
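
For reference, a minimal shape-check sketch of the relaxed broadcasting rule introduced above in `ggml_can_out_prod`/`ggml_out_prod`. The tensor sizes and the context setup are illustrative assumptions, not part of the patch; since `ggml_out_prod` creates its result tensor eagerly, the broadcast shape can be inspected without building or running a graph.

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // illustrative size, plenty for these tensors
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a is broadcastable to b in ne[2] and ne[3]: the patched ggml_can_out_prod
    // only requires matching ne[1] and b->ne[2] % a->ne[2] == 0, b->ne[3] % a->ne[3] == 0
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 1, 1); // [n=8, k=4, 1, 1]
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 6, 4, 2, 3); // [m=6, k=4, q=2, r=3]

    struct ggml_tensor * c = ggml_out_prod(ctx, a, b);

    // with the patch, ne[2] and ne[3] are taken from b -> expected [8, 6, 2, 3]
    printf("out_prod result ne = [%lld, %lld, %lld, %lld]\n",
        (long long) c->ne[0], (long long) c->ne[1],
        (long long) c->ne[2], (long long) c->ne[3]);

    ggml_free(ctx);
    return 0;
}
```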
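The `ggml_sum_repetitions` helper mentioned in the TODO comment does not exist in ggml; the following is a hypothetical sketch of what it could look like, factored directly out of the reshape/transpose/permute/`ggml_sum_rows` sequence the patch already uses for the `ne2_broadcasted`/`ne3_broadcasted` cases. It only covers that situation (ne[0] and ne[1] already match, repetitions confined to ne[2]/ne[3]); the name and signature are assumptions taken from the comment, not an existing API.

```c
// hypothetical helper (not part of ggml): sum the broadcast repetitions of u
// back into the shape of v, assuming u->ne = [n, m, nq*q1, nr*r1] and
// v->ne = [n, m, q1, r1], using only ops that already exist in ggml.c
static struct ggml_tensor * ggml_sum_repetitions(
        struct ggml_context * ctx,
        struct ggml_tensor  * u,
        struct ggml_tensor  * v) {
    const int64_t n  = u->ne[0];
    const int64_t m  = u->ne[1];
    const int64_t qq = u->ne[2];
    const int64_t rr = u->ne[3];
    const int64_t q1 = v->ne[2];
    const int64_t r1 = v->ne[3];
    GGML_ASSERT(u->ne[0] == v->ne[0]);
    GGML_ASSERT(u->ne[1] == v->ne[1]);
    GGML_ASSERT(qq % q1 == 0);
    GGML_ASSERT(rr % r1 == 0);
    const int64_t nq = qq/q1;
    const int64_t nr = rr/r1;
    struct ggml_tensor * t = u;
    if (nq > 1) {
        // fold the ne[2] repetitions into the row dimension and sum them
        t = ggml_reshape_4d(ctx, t, n*m, nq, q1, rr); // [n*m,nq,q1,rr]
        t = ggml_cont(ctx, ggml_transpose(ctx, t));   // [nq,n*m,q1,rr]
        t = ggml_sum_rows(ctx, t);                    // [1,n*m,q1,rr]
    }
    if (nr > 1) {
        // fold the ne[3] repetitions into the row dimension and sum them
        t = ggml_reshape_4d(ctx, t, n*m, q1, nr, r1);         // [n*m,q1,nr,r1]
        t = ggml_cont(ctx, ggml_permute(ctx, t, 1, 2, 0, 3)); // [nr,n*m,q1,r1]
        t = ggml_sum_rows(ctx, t);                            // [1,n*m,q1,r1]
    }
    // restore ne and n_dims of v
    return ggml_reshape(ctx, t, v);
}
```

If the revision being patched already provides `ggml_repeat_back` (the adjoint of `ggml_repeat`, which sums repetitions of its first argument into the shape of its second), the two broadcast branches in the backward pass could likely be collapsed into a single `ggml_repeat_back(ctx, s1_tg, src0)` call, avoiding the extra `ggml_cont` copies.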