fix backward pass for repeat
requires ggml_sum_rows
This commit is contained in:
parent
ba62c79bd5
commit
8b5b2f089e
1 changed files with 33 additions and 3 deletions
36
ggml.c
36
ggml.c
|
@ -13092,12 +13092,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
// TODO: is this really correct?
|
||||
// i think tensor->grad must be reshaped to [*src0->ne[[0,1,2]],-1] and then summed along last axis
|
||||
GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
|
||||
const int nc = tensor->ne[0];
|
||||
const int nr = tensor->ne[1];
|
||||
const int nc0 = src0->ne[0];
|
||||
const int nr0 = src0->ne[1];
|
||||
const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
// tensor->grad [nc,nr,1,1]
|
||||
// reshape [nc0,nc/nc0,nr0,nr/nr0]
|
||||
// permute [nc0,nr0,nc/nc0,nr/nr0]
|
||||
// substitute [nc0,nr0,ncr,nrr]
|
||||
// reshape [nc0*nr0,ncr*nrr,1,1]
|
||||
// transpose [ncr*nrr,nc0*nr0,1,1]
|
||||
// sum rows [1,nc0*nr0,1,1]
|
||||
// transpose [nc0*nr0,1,1]
|
||||
// reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
|
||||
// add to src0->grad
|
||||
|
||||
int64_t ne[4] = {nc0,ncr,nr0,nrr};
|
||||
|
||||
struct ggml_tensor* F00 = tensor->grad;
|
||||
struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
|
||||
struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
|
||||
struct ggml_tensor* F03 = ggml_cont (ctx, F02);
|
||||
struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
|
||||
struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
|
||||
struct ggml_tensor* F06 = ggml_cont (ctx, F05);
|
||||
struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
|
||||
struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
|
||||
struct ggml_tensor* F09 = ggml_cont (ctx, F08);
|
||||
struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
|
||||
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_sum(ctx, tensor->grad),
|
||||
F10,
|
||||
inplace);
|
||||
}
|
||||
} break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue