llama: rwkv6: Fix group_norm assertion failure with Metal
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
683d70cb68
commit
ee1b78c091
1 changed files with 1 additions and 1 deletions
|
@ -9501,7 +9501,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
|
|||
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));
|
||||
|
||||
// ggml_group_norm considers groups in the third dimension.
|
||||
cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens);
|
||||
cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
|
||||
cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
|
||||
// Convert back to a regular vector.
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue