From c054079fb81a25acf941c9c27c19087c0eaed632 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 14 May 2023 20:56:50 +0200
Subject: [PATCH] improve performance of mul_mat backward pass

avoid transpose by using mul_mat with swapped arguments
---
 ggml.c | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/ggml.c b/ggml.c
index 06e3feea0..9a0a07aa5 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13050,15 +13050,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 //     src1, // [n,p]
                                 //     tensor->grad), // [m,p]
                                 // for now just using A*B==(B.T*A.T).T
-                                ggml_cont(ctx, // [n,m]
-                                    ggml_transpose(ctx, // [n,m]
-                                        ggml_mul_mat(ctx, // [m,n]
-                                            ggml_cont(ctx, // [p,m]
-                                                ggml_transpose(ctx, // [p,m]
-                                                    tensor->grad)), // [m,p]
-                                            ggml_cont(ctx, // [p,n]
-                                                ggml_transpose(ctx, // [p,n]
-                                                    src1))))), // [n,p]
+                                ggml_mul_mat(ctx, // [n,m]
+                                    ggml_cont(ctx, // [p,n]
+                                        ggml_transpose(ctx, // [p,n]
+                                            src1)), // [n,p]
+                                    ggml_cont(ctx, // [p,m]
+                                        ggml_transpose(ctx, // [p,m]
+                                            tensor->grad))), // [m,p]
                                 inplace);
                 }
                 if (src1->grad) {
@@ -13070,6 +13068,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     ggml_cont(ctx, // [m,n]
                                         ggml_transpose(ctx, src0)), // [m,n]
                                     tensor->grad), // [m,p]
+
+                                // // when src0 is bigger than tensor->grad (this is the case in llama),
+                                // // avoid transpose of src0, rather transpose smaller tensor->grad
+                                // // and then use ggml_out_prod
+                                // ggml_out_prod(ctx, // [n,p]
+                                //     src0, // [n,m]
+                                //     ggml_cont(ctx, // [p,m]
+                                //         ggml_transpose(ctx, // [p,m]
+                                //             tensor->grad)), // [m,p]
                                 inplace);
                 }
             } break;