rwkv: skip computing output for unused tokens for hybrid models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2025-01-31 15:48:36 +08:00
parent cffd099aad
commit 9cad1ca194

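Note on the pattern (not part of the diff): llama.cpp builds an inp_out_ids selector listing the tokens whose logits are actually requested, and at the last layer the hidden-state rows of all other tokens are dropped with ggml_get_rows, so the remaining per-row work (final norm, output projection) only runs on rows that will be read back. A minimal, self-contained toy sketch of that row selection, with made-up sizes and values (on recent ggml, ggml_graph_compute_with_ctx is declared in ggml-cpu.h):

    // Toy sketch (not from this commit): keep only rows 1 and 3 of a 4-token
    // hidden state via ggml_get_rows, the same primitive used with inp_out_ids.
    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(params);

        const int n_embd = 2, n_tokens = 4, n_outputs = 2;
        struct ggml_tensor * x   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_outputs);

        // fill x with recognizable values: row t = {10*t, 10*t + 1}
        for (int t = 0; t < n_tokens; ++t) {
            for (int i = 0; i < n_embd; ++i) {
                ggml_set_f32_nd(x, i, t, 0, 0, (float)(10*t + i));
            }
        }
        ggml_set_i32_1d(ids, 0, 1); // only tokens 1 and 3 need logits
        ggml_set_i32_1d(ids, 1, 3);

        // sel has shape (n_embd, n_outputs): downstream ops see only these rows
        struct ggml_tensor * sel = ggml_get_rows(ctx, x, ids);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, sel);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

        for (int t = 0; t < n_outputs; ++t) {
            printf("kept row %d: %g %g\n", t,
                   ggml_get_f32_nd(sel, 0, t, 0, 0), ggml_get_f32_nd(sel, 1, t, 0, 0));
        }

        ggml_free(ctx);
        return 0;
    }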

@@ -7760,7 +7760,18 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            struct ggml_tensor * inp_ffn = x_norm_ffn;
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
+                x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
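The hunk above filters inp_ffn, x_prev and cur with the same inp_out_ids so that the channel-mix call and the residual ggml_add keep operating on matching row sets. Pruning before the element-wise ops is safe because row selection commutes with them; a small illustrative check (toy code, not from this commit):

    // Toy check: selecting rows before an element-wise add gives the same kept
    // rows as adding first and selecting after, which is why pruning
    // cur/x_prev/inp_ffn at the last layer does not change the requested outputs.
    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        const int n_embd = 3, n_tokens = 5, n_outputs = 2;
        struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
        struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_outputs);

        for (int t = 0; t < n_tokens; ++t) {
            for (int i = 0; i < n_embd; ++i) {
                ggml_set_f32_nd(a, i, t, 0, 0, (float)(t + i));
                ggml_set_f32_nd(b, i, t, 0, 0, (float)(t * i));
            }
        }
        ggml_set_i32_1d(ids, 0, 0); // only tokens 0 and 4 need logits
        ggml_set_i32_1d(ids, 1, 4);

        // prune first, then add (what the patch does at the last layer) ...
        struct ggml_tensor * pruned = ggml_add(ctx, ggml_get_rows(ctx, a, ids), ggml_get_rows(ctx, b, ids));
        // ... vs add everything, then prune (what the old end-of-graph code did)
        struct ggml_tensor * full = ggml_get_rows(ctx, ggml_add(ctx, a, b), ids);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, pruned);
        ggml_build_forward_expand(gf, full);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        for (int t = 0; t < n_outputs; ++t) {
            for (int i = 0; i < n_embd; ++i) {
                printf("%g == %g\n", ggml_get_f32_nd(pruned, i, t, 0, 0), ggml_get_f32_nd(full, i, t, 0, 0));
            }
        }

        ggml_free(ctx);
        return 0;
    }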
@@ -7789,9 +7800,8 @@ struct llm_build_context {
         }
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+        // struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        // cur = ggml_get_rows(ctx0, cur, inp_out_ids);
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
@@ -7874,6 +7884,13 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
             // feed-forward network
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
@@ -7897,10 +7914,6 @@ struct llm_build_context {
         }
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
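The end-of-graph build_inp_out_ids/ggml_get_rows removed above becomes redundant once the last layer has already pruned its rows; this relies on the remaining end-of-graph ops (the final norm and the output projection) acting on each token row independently. An illustrative check with ggml_rms_norm standing in for the per-row norm (toy code, not from this commit):

    // Toy check: ggml_rms_norm is applied per row, so norming the pruned rows
    // equals pruning the normed rows; pruning at the last layer therefore
    // makes the old end-of-graph get_rows unnecessary.
    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        const int n_embd = 4, n_tokens = 3;
        struct ggml_tensor * x   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);

        for (int t = 0; t < n_tokens; ++t) {
            for (int i = 0; i < n_embd; ++i) {
                ggml_set_f32_nd(x, i, t, 0, 0, (float)(1 + t) * (float)(i + 1));
            }
        }
        ggml_set_i32_1d(ids, 0, 2); // only the last token's logits are needed

        const float eps = 1e-6f;
        struct ggml_tensor * norm_then_prune = ggml_get_rows(ctx, ggml_rms_norm(ctx, x, eps), ids);
        struct ggml_tensor * prune_then_norm = ggml_rms_norm(ctx, ggml_get_rows(ctx, x, ids), eps);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, norm_then_prune);
        ggml_build_forward_expand(gf, prune_then_norm);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        for (int i = 0; i < n_embd; ++i) {
            printf("%g == %g\n",
                   ggml_get_f32_nd(norm_then_prune, i, 0, 0, 0),
                   ggml_get_f32_nd(prune_then_norm, i, 0, 0, 0));
        }

        ggml_free(ctx);
        return 0;
    }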
@@ -7981,7 +7994,18 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            struct ggml_tensor * inp_ffn = x_norm_ffn;
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
+                x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+            cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -8010,10 +8034,6 @@ struct llm_build_context {
         }
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
@@ -8095,6 +8115,13 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
             // feed-forward network
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
@@ -8118,10 +8145,6 @@ struct llm_build_context {
         }
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);