diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0d029a765..e0ac78a5f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -247,6 +247,7 @@ static std::string var_to_str(ggml_op_pool pool) { #define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i) #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j) #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k) +#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l) // accept FLT_MAX as infinity @@ -1186,6 +1187,7 @@ struct test_pool2d : public test_case { struct test_im2col : public test_case { const ggml_type type_input; const ggml_type type_kernel; + const ggml_type dst_type; const std::array ne_input; const std::array ne_kernel; // stride @@ -1201,22 +1203,22 @@ struct test_im2col : public test_case { const bool is_2D; std::string vars() override { - return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D); + return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D); } - test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, + test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] int s0 = 1, int s1 = 1, int p0 = 1, int p1 = 1, int d0 = 1, int d1 = 1, bool is_2D = true) - : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {} + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); - ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F16); + ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type); return out; } }; @@ -1562,6 +1564,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1})); @@ -1694,7 +1699,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } test_cases.emplace_back(new test_alibi()); - test_cases.emplace_back(new test_im2col()); test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); test_cases.emplace_back(new test_concat(GGML_TYPE_I32));