replace inefficient repeat backward pass with dedicated repeat_back operation
This commit is contained in:
parent
c47df09842
commit
05cb629c8e
2 changed files with 150 additions and 39 deletions
183
ggml.c
183
ggml.c
|
@ -3297,6 +3297,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"SUM_ROWS",
|
||||
"MEAN",
|
||||
"REPEAT",
|
||||
"REPEAT_BACK",
|
||||
"ABS",
|
||||
"SGN",
|
||||
"NEG",
|
||||
|
@ -3340,7 +3341,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"MAP_BINARY",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
|
||||
static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -3359,6 +3360,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"Σx_k",
|
||||
"Σx/n",
|
||||
"repeat(x)",
|
||||
"repeat_back(x)",
|
||||
"abs(x)",
|
||||
"sgn(x)",
|
||||
"-x",
|
||||
|
@ -3402,7 +3404,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"f(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
|
||||
static_assert(GGML_OP_COUNT == 53, "GGML_OP_COUNT != 53");
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
||||
|
@ -4790,6 +4792,34 @@ struct ggml_tensor * ggml_repeat(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_repeat_back
|
||||
|
||||
struct ggml_tensor * ggml_repeat_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
GGML_ASSERT(ggml_can_repeat(b, a));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
if (ggml_are_same_shape(a, b) && !is_node) {
|
||||
return a;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
|
||||
|
||||
result->op = GGML_OP_REPEAT_BACK;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_abs
|
||||
|
||||
struct ggml_tensor * ggml_abs_impl(
|
||||
|
@ -8430,6 +8460,99 @@ static void ggml_compute_forward_repeat(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_repeat_back
//
// Forward pass of GGML_OP_REPEAT_BACK for f32: sums the tiles of src0
// (which is a repetition of dst's shape) into dst. Zeroes dst first,
// then accumulates every tile on top of it.

static void ggml_compute_forward_repeat_back_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    // this op is not parallelized — only thread 0 may run it
    GGML_ASSERT(params->ith == 0);
    // dst's shape must tile src0's shape (each src0 dim is a multiple of dst's)
    GGML_ASSERT(ggml_can_repeat(dst, src0));

    // no work in the INIT/FINALIZE phases
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    // dst (reduced) extents per dimension
    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    // src0 (repeated) extents per dimension
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // dst strides in bytes
    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];

    // src0 strides in bytes
    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
    const size_t nb03 = src0->nb[3];

    // number of repetitions along each dimension;
    // guaranteed to be an integer due to the check in ggml_can_repeat
    const int nr0 = (int)(ne00/ne0);
    const int nr1 = (int)(ne01/ne1);
    const int nr2 = (int)(ne02/ne2);
    const int nr3 = (int)(ne03/ne3);

    // TODO: support for transposed / permuted tensors
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb00 == sizeof(float));

    // zero the accumulator: one memset-like sweep when contiguous,
    // otherwise row by row following dst's strides
    if (ggml_is_contiguous(dst)) {
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
    } else {
        for (int k3 = 0; k3 < ne3; k3++) {
            for (int k2 = 0; k2 < ne2; k2++) {
                for (int k1 = 0; k1 < ne1; k1++) {
                    ggml_vec_set_f32(ne0,
                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
                        0);
                }
            }
        }
    }

    // accumulate each repeated tile of src0 onto dst;
    // iN indexes the tile, kN the position inside the tile
    // TODO: maybe this is not optimal?
    for (int i3 = 0; i3 < nr3; i3++) {
        for (int k3 = 0; k3 < ne3; k3++) {
            for (int i2 = 0; i2 < nr2; i2++) {
                for (int k2 = 0; k2 < ne2; k2++) {
                    for (int i1 = 0; i1 < nr1; i1++) {
                        for (int k1 = 0; k1 < ne1; k1++) {
                            for (int i0 = 0; i0 < nr0; i0++) {
                                // add one src0 row (ne0 elements) onto the matching dst row
                                ggml_vec_acc_f32(ne0,
                                    (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
                                    (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
                            }
                        }
                    }
                }
            }
        }
    }
}
|
||||
|
||||
static void ggml_compute_forward_repeat_back(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_repeat_back_f32(params, src0, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_abs
|
||||
|
||||
static void ggml_compute_forward_abs_f32(
|
||||
|
@ -12770,6 +12893,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_repeat(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_ABS:
|
||||
{
|
||||
ggml_compute_forward_abs(params, tensor->src0, tensor);
|
||||
|
@ -13113,43 +13240,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
|
||||
const int nc = tensor->ne[0];
|
||||
const int nr = tensor->ne[1];
|
||||
const int nc0 = src0->ne[0];
|
||||
const int nr0 = src0->ne[1];
|
||||
const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
// tensor->grad [nc,nr,1,1]
|
||||
// reshape [nc0,nc/nc0,nr0,nr/nr0]
|
||||
// permute [nc0,nr0,nc/nc0,nr/nr0]
|
||||
// substitute [nc0,nr0,ncr,nrr]
|
||||
// reshape [nc0*nr0,ncr*nrr,1,1]
|
||||
// transpose [ncr*nrr,nc0*nr0,1,1]
|
||||
// sum rows [1,nc0*nr0,1,1]
|
||||
// transpose [nc0*nr0,1,1]
|
||||
// reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
|
||||
// add to src0->grad
|
||||
|
||||
int64_t ne[4] = {nc0,ncr,nr0,nrr};
|
||||
|
||||
struct ggml_tensor* F00 = tensor->grad;
|
||||
struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
|
||||
struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
|
||||
struct ggml_tensor* F03 = ggml_cont (ctx, F02);
|
||||
struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
|
||||
struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
|
||||
struct ggml_tensor* F06 = ggml_cont (ctx, F05);
|
||||
struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
|
||||
struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
|
||||
struct ggml_tensor* F09 = ggml_cont (ctx, F08);
|
||||
struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
|
||||
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
F10,
|
||||
inplace);
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
if (src0->grad) {
|
||||
// TODO: test this
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ABS:
|
||||
|
@ -13941,6 +14045,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_ABS:
|
||||
case GGML_OP_SGN:
|
||||
case GGML_OP_NEG:
|
||||
|
|
6
ggml.h
6
ggml.h
|
@ -279,6 +279,7 @@ extern "C" {
|
|||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_MEAN,
|
||||
GGML_OP_REPEAT,
|
||||
GGML_OP_REPEAT_BACK,
|
||||
GGML_OP_ABS,
|
||||
GGML_OP_SGN,
|
||||
GGML_OP_NEG,
|
||||
|
@ -596,6 +597,11 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_repeat_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_abs(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue