ggml : add epsilon as a parameter for group_norm (#8818)
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
cdd1889de6
commit
2d5dd7bb3f
7 changed files with 38 additions and 24 deletions
|
@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
}
|
||||
|
||||
static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
const int num_groups, const int group_size,
|
||||
const int num_groups, const float eps, const int group_size,
|
||||
const int ne_elements, queue_ptr stream, int device) {
|
||||
static const float eps = 1e-6f;
|
||||
if (group_size < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
|
@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
|||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
int num_groups = dst->op_params[0];
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
||||
|
||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue