ggml : add epsilon as a parameter for group_norm (#8818)

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-08-06 15:26:46 +08:00 committed by GitHub
parent cdd1889de6
commit 2d5dd7bb3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 38 additions and 24 deletions

View file

@ -5374,6 +5374,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_groups,
float eps,
bool inplace) {
bool is_node = false;
@ -5384,7 +5385,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op_params[0] = n_groups;
ggml_set_op_params_i32(result, 0, n_groups);
ggml_set_op_params_f32(result, 1, eps);
result->op = GGML_OP_GROUP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -5396,15 +5398,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
struct ggml_tensor * ggml_group_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_groups) {
return ggml_group_norm_impl(ctx, a, n_groups, false);
int n_groups,
float eps) {
return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
}
struct ggml_tensor * ggml_group_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_groups) {
return ggml_group_norm_impl(ctx, a, n_groups, true);
int n_groups,
float eps) {
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
}
// ggml_mul_mat
@ -12095,10 +12099,11 @@ static void ggml_compute_forward_group_norm_f32(
GGML_TENSOR_UNARY_OP_LOCALS
const float eps = 1e-6f; // TODO: make this a parameter
// TODO: optimize
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
int n_channels = src0->ne[2];
int n_groups = dst->op_params[0];
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;