diff --git a/ggml-metal.metal b/ggml-metal.metal
index 342fa3707..335b990d9 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -3379,12 +3379,12 @@ kernel void kernel_concat(
 
     int64_t o[4] = {0, 0, 0, 0};
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+    if (dim > 0 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
         src = src0;
         o[dim] = 0;
     } else {
         src = src1;
-        o[dim] = ne00;
+        o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03);
     }
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 876f7329a..b200ccccd 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1260,18 +1260,18 @@ struct test_im2col : public test_case {
 struct test_concat : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne_a;
-    const int dim;
     const int64_t ne_b_d;
+    const int dim;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, dim, ne_b_d);
+        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
-            int dim = 2,
-            int64_t ne_b_d = 10)
-        : type(type), ne_a(ne_a), dim(dim), ne_b_d(ne_b_d) {}
+            int64_t ne_b_d = 10,
+            int dim = 2)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         auto ne_b = ne_a;
@@ -2216,8 +2216,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     for (int dim : { 0, 1, 2, 3, }) {
-        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {10, 10, 10, 10}, dim));
-        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {10, 10, 10, 10}, dim));
+        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
+        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
     }
 
     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {