diff --git a/ggml.c b/ggml.c index 89490c257..4415ce19b 100644 --- a/ggml.c +++ b/ggml.c @@ -12890,13 +12890,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // src1, // [n,p] // tensor->grad), // [m,p] // for now just using A*B==(B.T*A.T).T - ggml_cont(ctx, // [n,m] - ggml_transpose(ctx, // [n,m] - ggml_mul_mat(ctx, // [m,n] - ggml_cont(ctx, // [p,m] - ggml_transpose(ctx, tensor->grad)), // [p,m] - ggml_cont(ctx, // [p,n] - ggml_transpose(ctx, src1))))), // [p,n] + ggml_cont(ctx, // [n,m] + ggml_transpose(ctx, // [n,m] + ggml_mul_mat(ctx, // [m,n] + ggml_cont(ctx, // [p,m] + ggml_transpose(ctx, // [p,m] + tensor->grad)), // [m,p] + ggml_cont(ctx, // [p,n] + ggml_transpose(ctx, // [p,n] + src1))))), // [n,p] inplace); } if (src1->grad) {