align shape annotations
This commit is contained in:
parent
fea42be47a
commit
93106504fd
1 changed files with 14 additions and 12 deletions
26
ggml.c
26
ggml.c
|
@ -12857,10 +12857,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
|
||||||
// ds1 = t.T.dot(dt)
|
// ds1 = t.T.dot(dt)
|
||||||
|
|
||||||
// tensor.T == (src0 @ src1.T).T
|
|
||||||
// tensor.shape [m,p]
|
// tensor.shape [m,p]
|
||||||
// src0.shape [n,m]
|
// src0.shape [n,m]
|
||||||
// src1.shape [n,p]
|
// src1.shape [n,p]
|
||||||
|
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
|
@ -12870,14 +12869,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
src0->grad,
|
src0->grad,
|
||||||
// ds0 = dt.dot(s1.T)
|
// ds0 = dt.dot(s1.T)
|
||||||
// ggml_out_prod(ctx, // [n,m]
|
// ggml_out_prod(ctx, // [n,m]
|
||||||
// src1, // [n,p]
|
// src1, // [n,p]
|
||||||
// tensor->grad), // [m,p]
|
// tensor->grad), // [m,p]
|
||||||
// for now just using A*B==(B.T*A.T).T
|
// for now just using A*B==(B.T*A.T).T
|
||||||
ggml_cont(ctx, // [n,m] not necessary TODO: investigate influence on speed
|
ggml_cont(ctx, // [n,m]
|
||||||
ggml_transpose(ctx, // [n,m]
|
ggml_transpose(ctx, // [n,m]
|
||||||
ggml_mul_mat(ctx, // [m,n]
|
ggml_mul_mat(ctx, // [m,n]
|
||||||
ggml_cont(ctx, ggml_transpose(ctx, tensor->grad)), // [p,m]
|
ggml_cont(ctx, // [p,m]
|
||||||
ggml_cont(ctx, ggml_transpose(ctx, src1))))), // [p,n]
|
ggml_transpose(ctx, tensor->grad)), // [p,m]
|
||||||
|
ggml_cont(ctx, // [p,n]
|
||||||
|
ggml_transpose(ctx, src1))))), // [p,n]
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
|
@ -12885,9 +12886,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
ggml_add_impl(ctx,
|
ggml_add_impl(ctx,
|
||||||
src1->grad,
|
src1->grad,
|
||||||
// ds1 = s0.T.dot(dt):
|
// ds1 = s0.T.dot(dt):
|
||||||
ggml_mul_mat(ctx, // [n,p]
|
ggml_mul_mat(ctx, // [n,p]
|
||||||
ggml_cont(ctx, ggml_transpose(ctx, src0)), // [m,n]
|
ggml_cont(ctx, // [m,n]
|
||||||
tensor->grad), // [m,p]
|
ggml_transpose(ctx, src0)), // [m,n]
|
||||||
|
tensor->grad), // [m,p]
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue