From 1a82788028337d114e7949ca3262ccb3869a5dd6 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Mon, 29 Jan 2024 11:21:03 +0800 Subject: [PATCH] ADD POOL2D test case in test-backend-ops.cpp --- ggml-cuda.cu | 1 + tests/test-backend-ops.cpp | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4fb5ce784..3387d54c9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -11178,6 +11178,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ROPE: case GGML_OP_ALIBI: case GGML_OP_IM2COL: + case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ac1ae8ad2..45f6fc36a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1140,6 +1140,40 @@ struct test_alibi : public test_case { } }; +// GGML_OP_POOL2D +struct test_pool2d : public test_case { + enum ggml_op_pool pool_type; + const ggml_type type_input; + const std::array ne_input; + // kernel size + const int k0; + const int k1; + // stride + const int s0; + const int s1; + // padding + const int p0; + const int p1; + + std::string vars() override { + return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1); + } + + test_pool2d(enum ggml_op_pool pool_type = GGML_OP_POOL_AVG, + ggml_type type_input = GGML_TYPE_F32, + std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] + int k0 = 3, int k1 = 3, + int s0 = 1, int s1 = 1, + int p0 = 1, int p1 = 1) + : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); + ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1); + return out; + } +}; + // GGML_OP_IM2COL struct test_im2col : public test_case { const ggml_type type_input; @@ -1502,6 +1536,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_AVG)); + test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_MAX)); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));