RWKV speed enhancements (batch processing); fixed an RWKV token processing bug
This commit is contained in:
parent
860fb026df
commit
9b6c35b651
2 changed files with 45 additions and 35 deletions
|
@ -409,7 +409,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
}
|
}
|
||||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||||
{
|
{
|
||||||
n_batch = 1;
|
|
||||||
std::string word;
|
std::string word;
|
||||||
read_rwkv_vocab();
|
read_rwkv_vocab();
|
||||||
int vocabsiz = rwkv_vocab.size();
|
int vocabsiz = rwkv_vocab.size();
|
||||||
|
@ -425,6 +424,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
|
|
||||||
if (file_format == FileFormat::RWKV_1)
|
if (file_format == FileFormat::RWKV_1)
|
||||||
{
|
{
|
||||||
|
n_batch = 1;
|
||||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||||
|
|
||||||
//setup buffers for rwkv state
|
//setup buffers for rwkv state
|
||||||
|
@ -453,6 +453,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
n_batch = 10; //use sequence mode to speedup
|
||||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||||
|
|
||||||
//setup buffers for rwkv state
|
//setup buffers for rwkv state
|
||||||
|
@ -946,6 +947,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
||||||
{
|
{
|
||||||
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||||
|
current_context_tokens.pop_back();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1014,8 +1016,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
{
|
||||||
|
if(embd.size()>1)
|
||||||
|
{
|
||||||
|
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||||
|
}
|
||||||
|
|
||||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue