llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-08-13 17:41:34 +08:00
parent 683d70cb68
commit ee1b78c091

View file

@@ -9501,7 +9501,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));
// ggml_group_norm considers groups in the third dimension.
-cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens);
+cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
// Convert back to a regular vector.
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);