ggml : add epsilon as a parameter for group_norm (#8818)
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
cdd1889de6
commit
2d5dd7bb3f
7 changed files with 38 additions and 24 deletions
|
@ -1511,6 +1511,7 @@ struct test_group_norm : public test_case {
|
|||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int32_t num_groups;
|
||||
const float eps;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(type, ne, num_groups);
|
||||
|
@ -1518,12 +1519,13 @@ struct test_group_norm : public test_case {
|
|||
|
||||
test_group_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
int32_t num_groups = 32)
|
||||
: type(type), ne(ne), num_groups(num_groups) {}
|
||||
int32_t num_groups = 32,
|
||||
float eps = 1e-6f)
|
||||
: type(type), ne(ne), num_groups(num_groups), eps(eps) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
|
||||
ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue