diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 447944d41..c8ca4405f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1260,21 +1260,25 @@ struct test_im2col : public test_case { struct test_concat : public test_case { const ggml_type type; const std::array ne; - const int64_t b_ne2; + const int dim; + const int64_t b_ned; std::string vars() override { - return VARS_TO_STR3(type, ne, b_ne2); + return VARS_TO_STR4(type, ne, dim, b_ned); } test_concat(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, - int64_t b_ne2 = 10) - : type(type), ne(ne), b_ne2(b_ne2) {} + int dim = 2, + int64_t b_ned = 10) + : type(type), ne(ne), dim(dim), b_ned(b_ned) {} ggml_tensor * build_graph(ggml_context * ctx) override { + auto b_ne = ne; + b_ne[dim] = b_ned; ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]); - ggml_tensor * out = ggml_concat(ctx, a, b, 2); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, b_ne.data()); + ggml_tensor * out = ggml_concat(ctx, a, b, dim); return out; } }; @@ -2211,8 +2215,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); - test_cases.emplace_back(new test_concat(GGML_TYPE_I32)); + for (int dim : { 0, 1, 2, 3, }) { + test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {10, 10, 10, 10}, dim)); + test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {10, 10, 10, 10}, dim)); + } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));