Add Vulkan GROUP_NORM eps parameter

This commit is contained in:
0cc4m 2024-08-10 11:19:30 +02:00
parent 9e0ac9895c
commit 61d8388721

View file

@@ -4621,11 +4621,12 @@ static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-int * op_params = (int *)dst->op_params;
+const int * int_op_params = (const int *)dst->op_params;
+const float * float_op_params = (const float *)dst->op_params;
-uint32_t num_groups = op_params[0];
+const uint32_t num_groups = int_op_params[0];
-uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+const float eps = float_op_params[1];
-static const float eps = 1e-6f;
+const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
@@ -6988,7 +6989,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 } else if (tensor->op == GGML_OP_NORM) {
 tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
 } else if (tensor->op == GGML_OP_GROUP_NORM) {
-tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params);
+tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
 } else if (tensor->op == GGML_OP_RMS_NORM) {
 tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
 } else if (tensor->op == GGML_OP_SOFT_MAX) {