llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Author: Molly Sophia <mollysophia379@gmail.com>
Date:   2024-08-25 15:48:35 +08:00
Parent: e94778ade0
Commit: 7756afd8dd

2 changed files with 37 additions and 47 deletions

convert_hf_to_gguf.py

@@ -299,6 +299,7 @@ class Model:
                             gguf.MODEL_TENSOR.POS_EMBD,
                             gguf.MODEL_TENSOR.TOKEN_TYPES,
                             gguf.MODEL_TENSOR.SSM_CONV1D,
+                            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
                         )
                     )
                     or not name.endswith(".weight")
@@ -2764,6 +2765,7 @@ class RwkvModel(Model):
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
         self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
         self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_file_type(self.ftype)
         # required by llama.cpp, unused
         self.gguf_writer.add_head_count(0)
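Note: add_file_type(self.ftype) records the converter's output precision under the standard general.file_type GGUF key. A minimal sketch of reading that key back through ggml's C gguf API (not part of this commit; "model.gguf" is a placeholder path):

#include <stdio.h>
#include "ggml.h"  // the gguf_* declarations lived in ggml.h at the time of this commit

int main(void) {
    // no_alloc = true: read metadata only, do not allocate tensor data
    struct gguf_init_params params = { /*no_alloc*/ true, /*ctx*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
    if (!gctx) {
        return 1;
    }
    const int key_id = gguf_find_key(gctx, "general.file_type");  // -1 if the key is absent
    if (key_id >= 0) {
        printf("general.file_type = %u\n", gguf_get_val_u32(gctx, key_id));
    }
    gguf_free(gctx);
    return 0;
}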

src/llama.cpp

@@ -5161,6 +5161,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_1B: return "1B";
         case MODEL_1_3B: return "1.3B";
         case MODEL_1_4B: return "1.4B";
+        case MODEL_1_6B: return "1.6B";
         case MODEL_2B: return "2B";
         case MODEL_2_8B: return "2.8B";
         case MODEL_3B: return "3B";
@@ -15066,49 +15067,40 @@ struct llm_build_context {
         GGML_ASSERT(batch.equal_seqs);
         GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);

-        ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();

-        ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

-        for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
-            const llama_layer * layer = &model.layers[layer_i];
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];

             // (ab)using the KV cache to store the states
             struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[layer_i], state_copy, state_mask,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
                     hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
             struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[layer_i], state_copy, state_mask,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
                     hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

-            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);

-            token_shift = ggml_cont(
-                ctx0,
-                ggml_permute(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs),
-                    0, 2, 1, 3
-                )
-            );
-
-            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0);
-            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift));
-            att_shift = ggml_reshape_3d(ctx0, att_shift, n_embd, 1, n_seqs);
-            ffn_shift = ggml_reshape_3d(ctx0, ffn_shift, n_embd, 1, n_seqs);
+            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

-            struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
             struct ggml_tensor * x_prev = ggml_concat(
                 ctx0,
                 att_shift,
-                ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
                 1
             );

-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
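Note on the restructuring in this hunk (a standalone sketch with toy sizes, not code from this commit): each sequence's token-shift row stores the attention shift and the FFN shift back to back, so once the row is reshaped to (n_embd, 2, n_seqs) both halves are addressable as strided views. Two ggml_view_3d calls replace the old reshape, permute, cont, view_1d, reshape chain without moving any data:

#include "ggml.h"

int main(void) {
    const int64_t n_embd = 8, n_seqs = 2;  // arbitrary toy sizes
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // one layer's token-shift state as loaded from the cache: 2*n_embd floats per sequence
    struct ggml_tensor * token_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2 * n_embd * n_seqs);
    token_shift = ggml_reshape_3d(ctx, token_shift, n_embd, 2, n_seqs);

    // first half of each sequence's row: attention shift; second half: ffn shift
    struct ggml_tensor * att_shift = ggml_view_3d(ctx, token_shift, n_embd, 1, n_seqs,
            token_shift->nb[1], token_shift->nb[2], 0);
    struct ggml_tensor * ffn_shift = ggml_view_3d(ctx, token_shift, n_embd, 1, n_seqs,
            token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

    GGML_ASSERT(att_shift->ne[0] == n_embd && ffn_shift->ne[2] == n_seqs);

    ggml_free(ctx);
    return 0;
}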
@@ -15117,38 +15109,22 @@ struct llm_build_context {
                     wkv_states,
                     ggml_view_1d(
                         ctx0,
-                        kv_self.v_l[layer_i],
+                        kv_self.v_l[il],
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_type_size(kv_self.v_l[layer_i]->type)
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
                     )
                 )
             );

-            struct ggml_tensor * last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0, last_norm,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0)
-                )
-            );
-
-            x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
             x_prev = ggml_concat(
                 ctx0,
                 ffn_shift,
-                ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
-
-            last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0, last_norm,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift))
-                )
-            );

             token_shift = ggml_cont(
                 ctx0,
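The ggml_type_size to ggml_element_size swap in this hunk is stylistic as far as the stored states go: the recurrent-state tensors are F32, for which both calls return 4 bytes. A hypothetical helper (the name wkv_state_offset is illustrative, not a llama.cpp symbol) spelling out the byte-offset arithmetic used for the view into the cache:

#include "ggml.h"

// Hypothetical helper, not a llama.cpp function: byte offset of the wkv state
// belonging to cache slot kv_head. Every preceding slot holds n_embd_v_s
// floats, and ggml view offsets are measured in bytes, hence the
// ggml_element_size() factor (4 for the F32 state tensors).
size_t wkv_state_offset(const struct ggml_tensor * cache, int64_t n_embd_v_s, int64_t kv_head) {
    return (size_t)(n_embd_v_s * kv_head) * ggml_element_size(cache);
}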
@@ -15159,20 +15135,32 @@ struct llm_build_context {
                 )
             );

+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+
+            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
                     ctx0,
                     ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type))
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
                 )
             );

-            if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
+            if ((il + 1) % hparams.rescale_every_n_layers == 0) {
                 cur = ggml_scale(ctx0, cur, 0.5F);
             }
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }

+        cur = inpL;
         ggml_tensor * inp_out_ids = build_inp_out_ids();
         cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
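Closing note on the new write-back (a sketch under assumed toy shapes, not the llama.cpp code): instead of two separate ggml_cpy calls into the halves of token_shift, the last normed token of the attention branch and of the FFN branch are concatenated along dim 1 and copied into the layer's cache row in one shot, matching the (n_embd, 2, n_seqs) layout the loop reads back at its top:

#include "ggml.h"

int main(void) {
    const int64_t n_embd = 8, n_seqs = 2;  // arbitrary toy sizes
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // stand-ins for last_norm_att / last_norm_ffn and one layer's cache row
    struct ggml_tensor * last_att = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_seqs);
    struct ggml_tensor * last_ffn = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_seqs);
    struct ggml_tensor * cache   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2 * n_embd * n_seqs);

    // (n_embd, 1, n_seqs) ++ (n_embd, 1, n_seqs) along dim 1 -> (n_embd, 2, n_seqs)
    struct ggml_tensor * token_shift = ggml_concat(ctx, last_att, last_ffn, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ggml_cpy(ctx,
            ggml_view_1d(ctx, token_shift, 2 * n_embd * n_seqs, 0),
            ggml_view_1d(ctx, cache, 2 * n_embd * n_seqs, 0)));
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    ggml_free(ctx);
    return 0;
}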