From 61d838872156aa49cb4c9a954fc50396c0f0391b Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 10 Aug 2024 11:19:30 +0200 Subject: [PATCH] Add Vulkan GROUP_NORM eps parameter --- ggml/src/ggml-vulkan.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 5a484c7b3..27861e04f 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -4621,11 +4621,12 @@ static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - int * op_params = (int *)dst->op_params; + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; - uint32_t num_groups = op_params[0]; - uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - static const float eps = 1e-6f; + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } @@ -6988,7 +6989,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params); + tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) {