fix backward pass for repeat

requires ggml_sum_rows
This commit is contained in:
xaedes 2023-05-01 01:11:12 +02:00
parent ba62c79bd5
commit 8b5b2f089e
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

36
ggml.c
View file

@ -13092,12 +13092,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
// TODO: is this really correct?
// i think tensor->grad must be reshaped to [*src0->ne[[0,1,2]],-1] and then summed along last axis
GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
const int nc = tensor->ne[0];
const int nr = tensor->ne[1];
const int nc0 = src0->ne[0];
const int nr0 = src0->ne[1];
const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
// tensor->grad [nc,nr,1,1]
// reshape [nc0,nc/nc0,nr0,nr/nr0]
// permute [nc0,nr0,nc/nc0,nr/nr0]
// substitute [nc0,nr0,ncr,nrr]
// reshape [nc0*nr0,ncr*nrr,1,1]
// transpose [ncr*nrr,nc0*nr0,1,1]
// sum rows [1,nc0*nr0,1,1]
// transpose [nc0*nr0,1,1]
// reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
// add to src0->grad
int64_t ne[4] = {nc0,ncr,nr0,nrr};
struct ggml_tensor* F00 = tensor->grad;
struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
struct ggml_tensor* F03 = ggml_cont (ctx, F02);
struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
struct ggml_tensor* F06 = ggml_cont (ctx, F05);
struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
struct ggml_tensor* F09 = ggml_cont (ctx, F08);
struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
src0->grad =
ggml_add_impl(ctx,
src0->grad,
ggml_sum(ctx, tensor->grad),
F10,
inplace);
}
} break;