llama: rwkv6: Use `ggml_norm
instead of
ggml_group_norm
`
Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
57decb4a38
commit
e94778ade0
1 changed files with 4 additions and 4 deletions
|
@ -9498,10 +9498,10 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
|
|||
cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
|
||||
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
|
||||
|
||||
// ggml_group_norm considers groups in the third dimension.
|
||||
cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
|
||||
cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
|
||||
// Convert back to a regular vector.
|
||||
// group norm with head_count groups
|
||||
cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
|
||||
cur = ggml_norm(ctx, cur, 64e-5f);
|
||||
// Convert back to regular vectors.
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
|
||||
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue