cuda : non-cont concat support (#7610)
* tests : add non-cont concat tests * cuda : non-cont concat support ggml-ci
This commit is contained in:
parent
210d99173d
commit
cce3dcffc5
2 changed files with 112 additions and 29 deletions
|
@ -1262,22 +1262,37 @@ struct test_concat : public test_case {
|
|||
const std::array<int64_t, 4> ne_a;
|
||||
const int64_t ne_b_d;
|
||||
const int dim;
|
||||
const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
|
||||
return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
|
||||
}
|
||||
|
||||
test_concat(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
|
||||
int64_t ne_b_d = 10,
|
||||
int dim = 2)
|
||||
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
|
||||
int dim = 2, int v = 0)
|
||||
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
auto ne_b = ne_a;
|
||||
ne_b[dim] = ne_b_d;
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
|
||||
ggml_tensor * a;
|
||||
if (v & 1) {
|
||||
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
|
||||
a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
} else {
|
||||
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
}
|
||||
ggml_tensor * b;
|
||||
if (v & 2) {
|
||||
auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
|
||||
b = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
|
||||
} else {
|
||||
b = ggml_new_tensor(ctx, type, 4, ne_b.data());
|
||||
}
|
||||
ggml_tensor * out = ggml_concat(ctx, a, b, dim);
|
||||
return out;
|
||||
}
|
||||
|
@ -2215,9 +2230,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
}
|
||||
}
|
||||
|
||||
for (int dim : { 0, 1, 2, 3, }) {
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
|
||||
for (int v : { 0, 1, 2, 3 }) {
|
||||
for (int dim : { 0, 1, 2, 3, }) {
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
|
||||
}
|
||||
}
|
||||
|
||||
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue