replace inefficient repeat backward pass with dedicated repeat_back operation

xaedes 2023-05-28 18:00:17 +02:00
parent c47df09842
commit 05cb629c8e
2 changed files with 150 additions and 39 deletions
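
Background: ggml_repeat_back is the adjoint of ggml_repeat. Where the forward op tiles a small tensor across a larger shape, the backward op sums the gradient tiles back into the source shape. A minimal standalone sketch of the 1-D case (illustrative only, not code from this commit; repeat_back_1d_f32 is a hypothetical helper name):

#include <stddef.h>

// Adjoint of a 1-D repeat: repeating src of length n0 r times yields n0*r
// outputs, so the backward pass sums the r gradient copies back into each
// of the n0 source slots.
static void repeat_back_1d_f32(float * dst, const float * grad, size_t n0, size_t r) {
    for (size_t k = 0; k < n0; k++) {
        dst[k] = 0.0f;
    }
    for (size_t i = 0; i < r; i++) {
        for (size_t k = 0; k < n0; k++) {
            dst[k] += grad[i*n0 + k];
        }
    }
}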

ggml.c (183 changed lines)

@@ -3297,6 +3297,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"SUM_ROWS",
"MEAN",
"REPEAT",
"REPEAT_BACK",
"ABS",
"SGN",
"NEG",
@@ -3340,7 +3341,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};
-static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
+static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3359,6 +3360,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"Σx_k",
"Σx/n",
"repeat(x)",
"repeat_back(x)",
"abs(x)",
"sgn(x)",
"-x",
@@ -3402,7 +3404,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
-static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
+static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4790,6 +4792,34 @@ struct ggml_tensor * ggml_repeat(
return result;
}
// ggml_repeat_back
struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_can_repeat(b, a));
bool is_node = false;
if (a->grad) {
is_node = true;
}
if (ggml_are_same_shape(a, b) && !is_node) {
return a;
}
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
result->op = GGML_OP_REPEAT_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
return result;
}
// ggml_abs
struct ggml_tensor * ggml_abs_impl(
@@ -8430,6 +8460,99 @@ static void ggml_compute_forward_repeat(
}
}
// ggml_compute_forward_repeat_back
static void ggml_compute_forward_repeat_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_can_repeat(dst, src0));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];
// guaranteed to be an integer due to the check in ggml_can_repeat
const int nr0 = (int)(ne00/ne0);
const int nr1 = (int)(ne01/ne1);
const int nr2 = (int)(ne02/ne2);
const int nr3 = (int)(ne03/ne3);
// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
if (ggml_is_contiguous(dst)) {
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
} else {
for (int k3 = 0; k3 < ne3; k3++) {
for (int k2 = 0; k2 < ne2; k2++) {
for (int k1 = 0; k1 < ne1; k1++) {
ggml_vec_set_f32(ne0,
(float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
0);
}
}
}
}
// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne3; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne2; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne1; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
ggml_vec_acc_f32(ne0,
(float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
(float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
}
}
}
}
}
}
}
}
static void ggml_compute_forward_repeat_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_repeat_back_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_abs
static void ggml_compute_forward_abs_f32(
@@ -12770,6 +12893,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_repeat(params, tensor->src0, tensor);
} break;
case GGML_OP_REPEAT_BACK:
{
ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
} break;
case GGML_OP_ABS:
{
ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13113,43 +13240,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
-GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
-const int nc = tensor->ne[0];
-const int nr = tensor->ne[1];
-const int nc0 = src0->ne[0];
-const int nr0 = src0->ne[1];
-const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
-// tensor->grad [nc,nr,1,1]
-// reshape [nc0,nc/nc0,nr0,nr/nr0]
-// permute [nc0,nr0,nc/nc0,nr/nr0]
-// substitute [nc0,nr0,ncr,nrr]
-// reshape [nc0*nr0,ncr*nrr,1,1]
-// transpose [ncr*nrr,nc0*nr0,1,1]
-// sum rows [1,nc0*nr0,1,1]
-// transpose [nc0*nr0,1,1]
-// reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
-// add to src0->grad
-int64_t ne[4] = {nc0,ncr,nr0,nrr};
-struct ggml_tensor* F00 = tensor->grad;
-struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
-struct ggml_tensor* F03 = ggml_cont (ctx, F02);
-struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-struct ggml_tensor* F06 = ggml_cont (ctx, F05);
-struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
-struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-struct ggml_tensor* F09 = ggml_cont (ctx, F08);
-struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
-src0->grad =
-    ggml_add_impl(ctx,
-        src0->grad,
-        F10,
-        inplace);
+src0->grad = ggml_add_impl(ctx,
+        src0->grad,
+        ggml_repeat_back(ctx, tensor->grad, src0->grad),
+        inplace);
}
} break;
case GGML_OP_REPEAT_BACK:
{
if (src0->grad) {
// TODO: test this
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_repeat(ctx, tensor->grad, src0->grad),
inplace);
}
} break;
case GGML_OP_ABS:
@@ -13941,6 +14045,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ABS:
case GGML_OP_SGN:
case GGML_OP_NEG:

ggml.h (6 changed lines)

@@ -279,6 +279,7 @@ extern "C" {
GGML_OP_SUM_ROWS,
GGML_OP_MEAN,
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_ABS,
GGML_OP_SGN,
GGML_OP_NEG,
@@ -596,6 +597,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,
struct ggml_tensor * a);
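
End-to-end, the payoff is in graph construction. A sketch under the assumption of this era's graph API (x and big are hypothetical tensors; x is marked as a parameter):

ggml_set_param(ctx, x);                         // x participates in backprop
struct ggml_tensor * y    = ggml_repeat(ctx, x, big);
struct ggml_tensor * loss = ggml_sum(ctx, y);
struct ggml_cgraph gf = ggml_build_forward(loss);
struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
// evaluating gb fills x->grad through a single GGML_OP_REPEAT_BACK node
// instead of the former reshape/permute/cont/sum_rows chain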