successfully test backward pass of repeat

This commit is contained in:
xaedes 2023-05-01 01:11:41 +02:00
parent 8b5b2f089e
commit 72bcfb50c8
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -478,6 +478,29 @@ int main(int argc, const char ** argv) {
}
}
// repeat
{
int64_t ne2[4];
get_random_dims(ne2, 4);
ne2[0] = ne[0] * ne2[0];
ne2[1] = ne[1] * ne2[1];
ne2[2] = 1;
ne2[3] = 1;
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
}
}
// abs (finite differences do not work)
//{
// const int nargs = 1;