From 72bcfb50c8ca77f024b1bd4af57c5a7b8f07e5a8 Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 1 May 2023 01:11:41 +0200 Subject: [PATCH] successfully test backward pass of repeat --- tests/test-grad0.c | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test-grad0.c b/tests/test-grad0.c index ec9df5564..aa8d7a97f 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -478,6 +478,29 @@ int main(int argc, const char ** argv) { } } + // repeat + { + int64_t ne2[4]; + get_random_dims(ne2, 4); + + ne2[0] = ne[0] * ne2[0]; + ne2[1] = ne[1] * ne2[1]; + ne2[2] = 1; + ne2[3] = 1; + + const int nargs = 1; + for (int ndims = 1; ndims <= 2; ++ndims) { + x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); + x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); + ggml_set_param(ctx0, x[0]); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); + + check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + } + + } + // abs (finite differences do not work) //{ // const int nargs = 1;