ggml : add GGML_PAD_REFLECT_1D
operation (ggml/1034)
* ggml_pad_reflect_1d defined in header * implemented on CPU * called the forward pass * impl Metal kernel * added Metal kernel * added OP_PAD_REFLECT_1D in test-backend-ops.cpp * add test-pad-reflect-1d test case * test case support multiple backend
This commit is contained in:
parent
d405804be8
commit
c2082d93a8
6 changed files with 192 additions and 2 deletions
|
@ -2697,6 +2697,33 @@ struct test_pad : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_PAD_REFLECT_1D
|
||||
struct test_pad_reflect_1d : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
const int pad_0;
|
||||
const int pad_1;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
|
||||
}
|
||||
|
||||
test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
|
||||
int pad_0 = 10, int pad_1 = 9)
|
||||
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ARANGE
|
||||
struct test_arange : public test_case {
|
||||
const ggml_type type;
|
||||
|
@ -3816,6 +3843,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
||||
test_cases.emplace_back(new test_acc());
|
||||
test_cases.emplace_back(new test_pad());
|
||||
test_cases.emplace_back(new test_pad_reflect_1d());
|
||||
test_cases.emplace_back(new test_arange());
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue