Improve performance of the mul_mat backward pass
Avoid transposing the result by calling mul_mat with swapped arguments
This commit is contained in:
parent
1f2b76de01
commit
c054079fb8
1 changed file with 16 additions and 9 deletions
21
ggml.c
21
ggml.c
|
@ -13050,15 +13050,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// src1, // [n,p]
|
// src1, // [n,p]
|
||||||
// tensor->grad), // [m,p]
|
// tensor->grad), // [m,p]
|
||||||
// for now just using A*B==(B.T*A.T).T
|
// for now just using A*B==(B.T*A.T).T
|
||||||
ggml_cont(ctx, // [n,m]
|
ggml_mul_mat(ctx, // [n,m]
|
||||||
ggml_transpose(ctx, // [n,m]
|
|
||||||
ggml_mul_mat(ctx, // [m,n]
|
|
||||||
ggml_cont(ctx, // [p,m]
|
|
||||||
ggml_transpose(ctx, // [p,m]
|
|
||||||
tensor->grad)), // [m,p]
|
|
||||||
ggml_cont(ctx, // [p,n]
|
ggml_cont(ctx, // [p,n]
|
||||||
ggml_transpose(ctx, // [p,n]
|
ggml_transpose(ctx, // [p,n]
|
||||||
src1))))), // [n,p]
|
src1)), // [n,p]
|
||||||
|
ggml_cont(ctx, // [p,m]
|
||||||
|
ggml_transpose(ctx, // [p,m]
|
||||||
|
tensor->grad))), // [m,p]
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
|
@ -13070,6 +13068,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
ggml_cont(ctx, // [m,n]
|
ggml_cont(ctx, // [m,n]
|
||||||
ggml_transpose(ctx, src0)), // [m,n]
|
ggml_transpose(ctx, src0)), // [m,n]
|
||||||
tensor->grad), // [m,p]
|
tensor->grad), // [m,p]
|
||||||
|
|
||||||
|
// // when src0 is bigger than tensor->grad (this is the case in llama),
|
||||||
|
// // avoid transpose of src0, rather transpose smaller tensor->grad
|
||||||
|
// // and then use ggml_out_prod
|
||||||
|
// ggml_out_prod(ctx, // [n,p]
|
||||||
|
// src0, // [n,m]
|
||||||
|
// ggml_cont(ctx, // [p,m]
|
||||||
|
// ggml_transpose(ctx, // [p,m]
|
||||||
|
// tensor->grad)), // [m,p]
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue