Add logits conversion to rwkv5

Layl Bongers 2024-04-23 11:12:09 +02:00 committed by Molly Sophia
parent a866789603
commit 4e23d9715b


@@ -1345,6 +1345,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES
         LLM_ARCH_RWKV,
         {
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
             { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
         },
     },
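
The entries above map tensor enums to the base names used to look tensors up in the GGUF file; the loader appends a ".weight" or ".bias" suffix, and per-layer entries substitute the layer index into the "blk.%d." prefix. A minimal sketch of that naming scheme (the tensor_name() helper below is hypothetical, standing in for llama.cpp's tn(...) callable):

    #include <cstdio>
    #include <string>

    // Hypothetical stand-in for tn(...): expand a name-table entry such as
    // "blk.%d.attn_norm" into a concrete tensor name like "blk.3.attn_norm.weight",
    // or "output_norm" into "output_norm.bias".
    static std::string tensor_name(const char * base, const char * suffix, int layer = -1) {
        char buf[256];
        if (layer >= 0) {
            snprintf(buf, sizeof(buf), base, layer); // substitute the %d layer index
        } else {
            snprintf(buf, sizeof(buf), "%s", base);
        }
        return std::string(buf) + "." + suffix;
    }

    int main() {
        printf("%s\n", tensor_name("output_norm", "weight").c_str());         // output_norm.weight
        printf("%s\n", tensor_name("output", "weight").c_str());              // output.weight
        printf("%s\n", tensor_name("blk.%d.attn_norm", "weight", 3).c_str()); // blk.3.attn_norm.weight
        return 0;
    }
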
@@ -5801,6 +5803,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_RWKV:
             {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 // TODO: Re-using mamba keys right now, but RWKV isn't state-space
                 ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
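
The new get_key call reads the layer-norm epsilon from the GGUF metadata into hparams.f_norm_eps; it is the small constant added to the variance inside the layer norms that the output head added by this commit relies on. A plain scalar reference of that computation, as a sketch (not llama.cpp's llm_build_norm, which expresses the same math as ggml graph nodes):

    #include <cmath>
    #include <vector>

    // Reference layer norm over one n_embd-sized vector:
    //   y[i] = (x[i] - mean) / sqrt(var + eps) * w[i] + b[i]
    // eps is what hparams.f_norm_eps holds after ml.get_key(...).
    static std::vector<float> layer_norm(const std::vector<float> & x,
                                         const std::vector<float> & w,
                                         const std::vector<float> & b,
                                         float eps) {
        const size_t n = x.size();

        float mean = 0.0f;
        for (float v : x) mean += v;
        mean /= n;

        float var = 0.0f;
        for (float v : x) var += (v - mean) * (v - mean);
        var /= n;

        std::vector<float> y(n);
        const float inv = 1.0f / std::sqrt(var + eps);
        for (size_t i = 0; i < n; ++i) {
            y[i] = (x[i] - mean) * inv * w[i] + b[i];
        }
        return y;
    }
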
@@ -8234,6 +8238,13 @@ static bool llm_load_tensors(
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

+                // output
+                {
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                    model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                }
+
                 for (int i = 0; i < n_layer; ++i) {
                     ggml_context * ctx_layer = ctx_for_layer(i);
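
With these create_tensor calls, an RWKV GGUF file is expected to contain the output-head tensors in addition to the token embedding. A small, hypothetical listing (not part of llama.cpp) of the names and shapes implied by the calls above, with dimensions in ggml order (fastest-varying first):

    #include <cstdio>

    // Output-head tensors the RWKV loader now looks up in the GGUF file.
    struct expected_tensor {
        const char * name;
        const char * shape;
    };

    int main() {
        const expected_tensor expected[] = {
            { "token_embd.weight",  "[n_embd, n_vocab]" },
            { "output_norm.weight", "[n_embd]"          },
            { "output_norm.bias",   "[n_embd]"          },
            { "output.weight",      "[n_embd, n_vocab]" },
        };
        for (const auto & t : expected) {
            printf("%-20s %s\n", t.name, t.shape);
        }
        return 0;
    }
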
@@ -14734,6 +14745,14 @@ struct llm_build_context {
         // Dummy operation, just to copy, we're not doing anything with it right now
         ggml_tensor *output = ggml_scale(ctx0, input_embeddings, 1.0);

+        // Keep only the rows for the token positions whose outputs are actually needed
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        output = ggml_get_rows(ctx0, output, inp_out_ids);
+
+        // Output head, convert result vector to logits
+        output = llm_build_norm(ctx0, output, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        output = ggml_mul_mat(ctx0, model.output, output);
+
         // Mark the output as being the result
         cb(output, "result_output", -1);
         ggml_build_forward_expand(gf, output);
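
Taken together, the added graph nodes gather the hidden states of the requested token positions (ggml_get_rows with inp_out_ids), layer-normalize them with output_norm/output_norm_b, and multiply by model.output to produce n_vocab logits per position. A scalar sketch of the same flow, with the normalization step omitted for brevity (see the layer-norm sketch earlier); this is an illustration, not the ggml graph itself:

    #include <vector>

    // Scalar sketch of the added output head: for each requested token position,
    //   logits[v] = dot(w_out[v], hidden[id])
    // mirroring ggml_get_rows -> llm_build_norm -> ggml_mul_mat above.
    //   hidden:  n_tokens x n_embd hidden states from the RWKV layers
    //   out_ids: token positions whose logits are requested (inp_out_ids)
    //   w_out:   n_vocab x n_embd output matrix (model.output)
    static std::vector<std::vector<float>> output_head(
            const std::vector<std::vector<float>> & hidden,
            const std::vector<int>                & out_ids,
            const std::vector<std::vector<float>> & w_out) {
        std::vector<std::vector<float>> logits;
        for (int id : out_ids) {                       // ggml_get_rows: keep requested rows only
            const std::vector<float> & h = hidden[id]; // (layer norm omitted here for brevity)
            std::vector<float> row(w_out.size(), 0.0f);
            for (size_t v = 0; v < w_out.size(); ++v) { // ggml_mul_mat(model.output, output)
                for (size_t i = 0; i < h.size(); ++i) {
                    row[v] += w_out[v][i] * h[i];
                }
            }
            logits.push_back(row);
        }
        return logits;
    }
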