diff --git a/src/llama.cpp b/src/llama.cpp index 1fd91fcd7..67889111a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9501,7 +9501,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float)); // ggml_group_norm considers groups in the third dimension. - cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens); + cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens); cur = ggml_group_norm(ctx, cur, head_count, 64e-5f); // Convert back to a regular vector. cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);