rwkv speed enhancements (batch processing), fixed a rwkv token processing bug
This commit is contained in:
parent
860fb026df
commit
9b6c35b651
2 changed files with 45 additions and 35 deletions
|
@ -409,7 +409,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
n_batch = 1;
|
||||
std::string word;
|
||||
read_rwkv_vocab();
|
||||
int vocabsiz = rwkv_vocab.size();
|
||||
|
@ -425,6 +424,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
n_batch = 1;
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
|
@ -453,6 +453,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
else
|
||||
{
|
||||
n_batch = 10; //use sequence mode to speedup
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
|
@ -946,6 +947,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
||||
{
|
||||
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||
current_context_tokens.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1014,8 +1016,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(embd.size()>1)
|
||||
{
|
||||
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
|
||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue