improve performance of mul_mat backward pass

avoid transpose by using mul_mat with swapped arguments
xaedes 2023-05-14 20:56:50 +02:00
parent 1f2b76de01
commit c054079fb8

ggml.c (25 changed lines)
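
The change in the first hunk below exploits how ggml_mul_mat shapes its result: it contracts the two arguments over their shared ne0 dimension and returns a tensor of shape [a->ne[1], b->ne[1]], so swapping the arguments yields exactly the transpose of the original product. The old code therefore needed an extra ggml_transpose plus a ggml_cont copy to bring the [m,n] product into the [n,m] layout of src0->grad; the new code produces [n,m] directly. The standalone sketch below is not part of the commit; it just checks that identity numerically, assuming the ggml API of this period (ggml_build_forward, ggml_graph_compute(ctx, &gf), ggml_set_f32_1d / ggml_get_f32_1d), with k, a, b as arbitrary example sizes.

// minimal check: ggml_mul_mat(x, y) == ggml_cont(ggml_transpose(ggml_mul_mat(y, x)))
#include <math.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // x: [k,a], y: [k,b] in ggml's [ne0,ne1] order; ne0 is the shared inner dimension
    const int k = 4, a = 3, b = 5;
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, a);
    struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, b);
    for (int i = 0; i < k*a; ++i) ggml_set_f32_1d(x, i, (float)(i+1));
    for (int i = 0; i < k*b; ++i) ggml_set_f32_1d(y, i, 0.5f*(float)i - 1.0f);

    // the form used by the new code: product already in [a,b] layout
    struct ggml_tensor * direct = ggml_mul_mat(ctx, x, y);                     // [a,b]

    // the form used by the old code: swapped product, then transpose + cont copy
    struct ggml_tensor * via_transpose =
        ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, y, x)));          // [a,b]

    struct ggml_cgraph gf = ggml_build_forward(direct);
    ggml_build_forward_expand(&gf, via_transpose);
    ggml_graph_compute(ctx, &gf);

    float max_diff = 0.0f;
    for (int i = 0; i < a*b; ++i) {
        float d = fabsf(ggml_get_f32_1d(direct, i) - ggml_get_f32_1d(via_transpose, i));
        if (d > max_diff) max_diff = d;
    }
    printf("max abs diff: %g\n", max_diff); // expected: 0

    ggml_free(ctx);
    return 0;
}

Besides removing the transpose node, the swap drops a full ggml_cont copy of the [n,m] gradient from every backward pass through the mul_mat.
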

@@ -13050,15 +13050,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     //   src1,                       // [n,p]
                     //   tensor->grad),              // [m,p]
                     // for now just using A*B==(B.T*A.T).T
-                    ggml_cont(ctx,                   // [n,m]
-                      ggml_transpose(ctx,            // [n,m]
-                        ggml_mul_mat(ctx,            // [m,n]
-                          ggml_cont(ctx,             // [p,m]
-                            ggml_transpose(ctx,      // [p,m]
-                              tensor->grad)),        // [m,p]
-                          ggml_cont(ctx,             // [p,n]
-                            ggml_transpose(ctx,      // [p,n]
-                              src1))))),             // [n,p]
+                    ggml_mul_mat(ctx,                // [n,m]
+                      ggml_cont(ctx,                 // [p,n]
+                        ggml_transpose(ctx,          // [p,n]
+                          src1)),                    // [n,p]
+                      ggml_cont(ctx,                 // [p,m]
+                        ggml_transpose(ctx,          // [p,m]
+                          tensor->grad))),           // [m,p]
                     inplace);
                 }
                 if (src1->grad) {
@@ -13070,6 +13068,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                       ggml_cont(ctx,                 // [m,n]
                         ggml_transpose(ctx, src0)),  // [m,n]
                       tensor->grad),                 // [m,p]
+                    // // when src0 is bigger than tensor->grad (this is the case in llama),
+                    // // avoid transpose of src0, rather transpose smaller tensor->grad
+                    // // and then use ggml_out_prod
+                    // ggml_out_prod(ctx,             // [n,p]
+                    //   src0,                        // [n,m]
+                    //   ggml_cont(ctx,               // [p,m]
+                    //     ggml_transpose(ctx,        // [p,m]
+                    //       tensor->grad)),          // [m,p]
                     inplace);
                 }
             } break;
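
The commented-out block added in the second hunk records a further idea rather than active code: src0 is bigger than tensor->grad (as the comment notes is the case in llama), so the ggml_cont(ggml_transpose(src0)) in the active src1->grad path is the expensive part; transposing the smaller tensor->grad and handing it to ggml_out_prod would avoid that copy. A rough reading of that plan as a standalone helper is sketched below. It is hypothetical: it assumes ggml_out_prod exists with the (ctx, a, b) call shape used in the comment and yields the [n,p] result the dimension annotations indicate, and the helper name is made up for illustration.

// hypothetical helper, not part of this commit: the src1->grad alternative from the
// commented-out plan above, assuming ggml_out_prod(ctx, a, b) takes a [n,m] and b [p,m]
// and returns [n,p] as the dimension comments indicate
#include "ggml.h"

// src0:        [n,m]  (in llama, a large weight matrix)
// tensor_grad: [m,p]  (gradient flowing back into the mul_mat result)
// returns:     [n,p]  (gradient w.r.t. src1)
static struct ggml_tensor * src1_grad_via_out_prod(
        struct ggml_context * ctx,
        struct ggml_tensor  * src0,
        struct ggml_tensor  * tensor_grad) {
    // transpose (and cont-copy) the smaller tensor_grad instead of the larger src0
    return ggml_out_prod(ctx,
        src0,                                        // [n,m]
        ggml_cont(ctx,
            ggml_transpose(ctx, tensor_grad)));      // [p,m]
}
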