From e94778ade0c7224a15989563adf2ed4a7a046a8c Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Sun, 25 Aug 2024 12:36:29 +0800 Subject: [PATCH] llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm`` Co-authored-by: compilade --- src/llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 8b2b920d2..68d53672d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9498,10 +9498,10 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float)); - // ggml_group_norm considers groups in the third dimension. - cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens); - cur = ggml_group_norm(ctx, cur, head_count, 64e-5f); - // Convert back to a regular vector. + // group norm with head_count groups + cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens); + cur = ggml_norm(ctx, cur, 64e-5f); + // Convert back to regular vectors. cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens); cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);