From ee1b78c0911476d80517d76e602820795649d746 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Tue, 13 Aug 2024 17:41:34 +0800 Subject: [PATCH] llama: rwkv6: Fix group_norm assertion failure with Metal Signed-off-by: Molly Sophia --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 1fd91fcd7..67889111a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9501,7 +9501,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float)); // ggml_group_norm considers groups in the third dimension. - cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens); + cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens); cur = ggml_group_norm(ctx, cur, head_count, 64e-5f); // Convert back to a regular vector. cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);