Add Vulkan GROUP_NORM eps parameter
This commit is contained in:
parent
9e0ac9895c
commit
61d8388721
1 changed file with 6 additions and 5 deletions
|
@@ -4621,11 +4621,12 @@ static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
int * op_params = (int *)dst->op_params;
|
const int * int_op_params = (const int *)dst->op_params;
|
||||||
|
const float * float_op_params = (const float *)dst->op_params;
|
||||||
|
|
||||||
uint32_t num_groups = op_params[0];
|
const uint32_t num_groups = int_op_params[0];
|
||||||
uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
const float eps = float_op_params[1];
|
||||||
static const float eps = 1e-6f;
|
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
@@ -6988,7 +6989,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||||
} else if (tensor->op == GGML_OP_NORM) {
|
} else if (tensor->op == GGML_OP_NORM) {
|
||||||
tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
||||||
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
||||||
tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params);
|
tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
||||||
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
||||||
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
||||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue