diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2cc..cbe368c2d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1659,6 +1659,17 @@ extern "C" { struct ggml_tensor * b, int stride); + GGML_API struct ggml_tensor * ggml_conv_transpose_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); + enum ggml_op_pool { GGML_OP_POOL_MAX, GGML_OP_POOL_AVG, diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index d55673b58..b745f7d0e 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -3895,7 +3895,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens switch (tensor->op) { case GGML_OP_CONV_TRANSPOSE_2D: - func = ggml_sycl_op_conv_2d; + func = ggml_sycl_op_conv_transpose_2d; break; case GGML_OP_CONV_TRANSPOSE_1D: func = ggml_sycl_op_conv_transpose_1d; diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp index 6b328e1a8..e1f57a6b8 100644 --- a/ggml/src/ggml-sycl/conv.cpp +++ b/ggml/src/ggml-sycl/conv.cpp @@ -99,7 +99,7 @@ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_ } -void ggml_sycl_op_conv_2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, +void ggml_sycl_op_conv_transpose_2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) { const void * src0_d = (const void *)src0->data; const void * src1_d = (const void *)src1->data; diff --git a/ggml/src/ggml-sycl/conv.hpp b/ggml/src/ggml-sycl/conv.hpp index 0dc8f9906..b61b14835 100644 --- a/ggml/src/ggml-sycl/conv.hpp +++ b/ggml/src/ggml-sycl/conv.hpp @@ -18,7 +18,7 @@ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst); -void ggml_sycl_op_conv_2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, +void 
ggml_sycl_op_conv_transpose_2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst); #endif // GGML_SYCL_CONV_HPP diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 526e0fe1a..4ab48a57d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6770,35 +6770,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { -#ifdef GGML_SYCL_DNNL - bool is_node = false; - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - - const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); - const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - const int64_t ne[4] = { - OW, - OH, - a->ne[3], // OC - b->ne[3], // N - }; - - struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1}; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_TRANSPOSE_2D; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -#else struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = @@ -6811,7 +6783,6 @@ struct ggml_tensor * ggml_conv_2d( return result; -#endif } // ggml_conv_2d_sk_p0 @@ -6837,6 +6808,43 @@ static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int return (ins - 1) * s - 2 * p + ks; } +struct ggml_tensor * ggml_conv_transpose_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + GGML_ASSERT(a->ne[3] == b->ne[2]); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), + a->ne[2], b->ne[3], + }; + + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1}; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_TRANSPOSE_2D; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7bc3d3bbc..5f8f2ab42 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1337,6 +1337,35 @@ struct test_conv_2d : public test_case { } }; +struct test_conv_transpose_2d : public test_case { + const std::array<int64_t, 4> ne_input; + const std::array<int64_t, 4> ne_kernel; + + const int s0; // stride + const int p0; // padding + const int d0; // dilation + const int s1; // stride + const int p1; // padding + const int d1; // dilation + + std::string vars() override { + return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0); + } + + test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1] + std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1] + int s0 = 1, int p0 = 0, int d0 = 1, + int s1 = 1, int p1 = 0, int d1 = 1) + : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0), s1(s1), p1(p1), d1(d1){} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data()); + ggml_tensor * out = ggml_conv_transpose_2d(ctx, kernel, input, s0, s1, p0, p1, d0, d1); + return out; + } +}; + // GGML_OP_IM2COL struct test_im2col : public test_case { const ggml_type type_input; @@ -2189,7 +2218,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); - 
test_cases.emplace_back(new test_conv_2d()); + test_cases.emplace_back(new test_conv_transpose_2d()); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));