test both gradients of mul_mat
This commit is contained in:
parent
20e3c1d2b4
commit
9345f4c3a5
2 changed files with 3 additions and 2 deletions
2
ggml.c
2
ggml.c
|
@ -5819,7 +5819,7 @@ struct ggml_tensor * ggml_cont_impl(
|
|||
bool is_node = false;
|
||||
|
||||
if (!inplace && a->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement backward
|
||||
// TODO: implement backward
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -385,7 +385,7 @@ int main(int argc, const char ** argv) {
|
|||
|
||||
// mul_mat
|
||||
{
|
||||
const int nargs = 1;
|
||||
const int nargs = 2;
|
||||
|
||||
for (int ndims = 2; ndims <= 2; ++ndims) {
|
||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||
|
@ -397,6 +397,7 @@ int main(int argc, const char ** argv) {
|
|||
}
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
ggml_set_param(ctx0, x[1]);
|
||||
|
||||
struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, m);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue