metal : add more general support for ggml_get_rows + tests
This commit is contained in:
parent
9064b1ca05
commit
2cbcba829f
4 changed files with 78 additions and 25 deletions
|
@ -488,17 +488,18 @@ struct test_get_rows : public test_case {
|
|||
const int n; // cols
|
||||
const int m; // rows
|
||||
const int r; // rows to get
|
||||
const int b; // batch size
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, n, m, r);
|
||||
}
|
||||
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
|
||||
: type(type), n(n), m(m), r(r) {}
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
|
||||
: type(type), n(n), m(m), r(r), b(b) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
|
||||
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
|
||||
ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
|
||||
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
|
||||
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
|
||||
return out;
|
||||
}
|
||||
|
@ -507,11 +508,11 @@ struct test_get_rows : public test_case {
|
|||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// rows
|
||||
std::vector<int> data(r);
|
||||
for (int i = 0; i < r; i++) {
|
||||
std::vector<int> data(r*b);
|
||||
for (int i = 0; i < r*b; i++) {
|
||||
data[i] = rand() % m;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
|
||||
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
|
@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
}
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
|
||||
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
|
||||
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
|
||||
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue