diff --git a/src/llama.cpp b/src/llama.cpp index 41fcd4cdd..5233ff82a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7760,7 +7760,18 @@ struct llm_build_context { ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0), 1 ); - cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev)); + + struct ggml_tensor * inp_ffn = x_norm_ffn; + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids); + x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev)); ggml_build_forward_expand(gf, cur); struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); @@ -7789,9 +7800,8 @@ struct llm_build_context { } cur = inpL; - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + // cur = ggml_get_rows(ctx0, cur, inp_out_ids); cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); cb(cur, "result_norm", -1); @@ -7874,6 +7884,13 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + // feed-forward network cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, @@ -7897,10 +7914,6 @@ struct llm_build_context { } cur = inpL; - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); @@ -7981,7 +7994,18 @@ struct llm_build_context { ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0), 1 ); - cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev)); + + struct ggml_tensor * inp_ffn = x_norm_ffn; + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids); + x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev)); ggml_build_forward_expand(gf, cur); struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); @@ -8010,10 +8034,6 @@ struct llm_build_context { } cur = inpL; - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); cb(cur, "result_norm", -1); @@ -8095,6 +8115,13 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + // feed-forward network cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, @@ -8118,10 +8145,6 @@ struct llm_build_context { } cur = inpL; - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1);