align shape annotations

This commit is contained in:
xaedes 2023-04-27 00:21:31 +02:00
parent fea42be47a
commit 93106504fd
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

12
ggml.c
View file

@ -12857,7 +12857,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
// ds1 = t.T.dot(dt) // ds1 = t.T.dot(dt)
// tensor.T == (src0 @ src1.T).T
// tensor.shape [m,p] // tensor.shape [m,p]
// src0.shape [n,m] // src0.shape [n,m]
// src1.shape [n,p] // src1.shape [n,p]
@ -12873,11 +12872,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// src1, // [n,p] // src1, // [n,p]
// tensor->grad), // [m,p] // tensor->grad), // [m,p]
// for now just using A*B==(B.T*A.T).T // for now just using A*B==(B.T*A.T).T
ggml_cont(ctx, // [n,m] not necessary TODO: investigate influence on speed ggml_cont(ctx, // [n,m]
ggml_transpose(ctx, // [n,m] ggml_transpose(ctx, // [n,m]
ggml_mul_mat(ctx, // [m,n] ggml_mul_mat(ctx, // [m,n]
ggml_cont(ctx, ggml_transpose(ctx, tensor->grad)), // [p,m] ggml_cont(ctx, // [p,m]
ggml_cont(ctx, ggml_transpose(ctx, src1))))), // [p,n] ggml_transpose(ctx, tensor->grad)), // [p,m]
ggml_cont(ctx, // [p,n]
ggml_transpose(ctx, src1))))), // [p,n]
inplace); inplace);
} }
if (src1->grad) { if (src1->grad) {
@ -12886,7 +12887,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad, src1->grad,
// ds1 = s0.T.dot(dt): // ds1 = s0.T.dot(dt):
ggml_mul_mat(ctx, // [n,p] ggml_mul_mat(ctx, // [n,p]
ggml_cont(ctx, ggml_transpose(ctx, src0)), // [m,n] ggml_cont(ctx, // [m,n]
ggml_transpose(ctx, src0)), // [m,n]
tensor->grad), // [m,p] tensor->grad), // [m,p]
inplace); inplace);
} }