RWKV speed enhancements (batch processing); fixed an RWKV token processing bug

Concedo 2023-06-13 16:02:12 +08:00
parent 860fb026df
commit 9b6c35b651
2 changed files with 45 additions and 35 deletions


@@ -409,7 +409,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_batch = 1;
std::string word;
read_rwkv_vocab();
int vocabsiz = rwkv_vocab.size();
@@ -425,6 +424,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
if (file_format == FileFormat::RWKV_1)
{
n_batch = 1;
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
//setup buffers for rwkv state
@@ -453,6 +453,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
else
{
n_batch = 10; //use sequence mode to speed up
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
//setup buffers for rwkv state
@@ -946,6 +947,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if(embd_inp.size()==0 && current_context_tokens.size()>0)
{
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
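//fix: also drop this token from the stored context, so the re-queued token is not counted twice once it is re-evaluated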
current_context_tokens.pop_back();
}
}
}
@@ -1015,7 +1017,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
else
{
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
if(embd.size()>1)
{
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
}
else
{
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
}
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
}
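This dispatch is the speedup itself: rwkv_eval_sequence (the rwkv.cpp sequence-mode call adopted in this commit) evaluates a whole batch of tokens in one pass, while rwkv_eval remains the single-token fallback. Below is a minimal sketch of feeding a long prompt through in n_batch-sized chunks while carrying the recurrent state between calls; the helper name is hypothetical, and only the two rwkv.cpp calls and the state_in/state_out handling are taken from the diff above.

//hypothetical helper: evaluate tokens n_batch at a time, threading the RWKV state across chunks
static bool eval_prompt_chunked(rwkv_context * ctx, const std::vector<uint32_t> & tokens, size_t n_batch)
{
    for (size_t i = 0; i < tokens.size(); i += n_batch)
    {
        size_t remaining = tokens.size() - i;
        size_t len = remaining < n_batch ? remaining : n_batch;
        bool ok = (len > 1)
            ? rwkv_eval_sequence(ctx, tokens.data() + i, len, ctx->state_in, ctx->state_out, ctx->logits_out)
            : rwkv_eval(ctx, tokens[i], ctx->state_in, ctx->state_out, ctx->logits_out);
        if (!ok) { return false; }
        ctx->state_in = ctx->state_out; //output state of this chunk seeds the next one
    }
    return true; //ctx->logits_out now holds the logits for the final token
}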


@@ -22,7 +22,7 @@ void timer_start()
double timer_check()
{
auto endtime = std::chrono::high_resolution_clock().now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endtime - bench_timer);
double time_taken = duration.count()/1000.0;
return time_taken;
}
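Usage of this timer pair is straightforward: timer_start() arms the shared bench_timer, and timer_check() returns the seconds elapsed since then. An illustrative example (the printed label is made up):

timer_start();
//...run the evaluation being benchmarked...
double secs = timer_check();
printf("\n[Eval took %.2f seconds]", secs);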
@@ -37,8 +37,8 @@ void print_vec(std::vector<std::string> &embd)
{
std::cout << ',';
}
first = false;
std::cout << i;
}
std::cout << "]\n";
}
@@ -52,8 +52,8 @@ void print_tok_vec(std::vector<int> &embd)
{
std::cout << ',';
}
first = false;
std::cout << i;
}
std::cout << "]\n";
}
@@ -78,7 +78,7 @@ void print_tok_vec(std::vector<float> &embd)
std::cout << "]\n";
}
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
FileFormat check_file_format(const std::string & fname)
{
std::vector<char> f_buf(1024*1024);
@@ -140,9 +140,9 @@ void print_tok_vec(std::vector<float> &embd)
fin.read((char *)&temp, sizeof(temp)); //n_layer
fin.read((char *)&temp, sizeof(temp)); //f16
const int32_t qntvr = temp / 1000;
temp %= 1000;
if (qntvr != 0)
{
if (qntvr == 1)
{
fileformat = FileFormat::GPT2_3;
@@ -168,7 +168,7 @@ void print_tok_vec(std::vector<float> &embd)
fin.read((char *)&temp, sizeof(temp)); //n_layer
fin.read((char *)&temp, sizeof(temp)); //n_rot
fin.read((char *)&temp, sizeof(temp)); //either par_res or ftype (for older ver)
if(temp!=0 && temp!=1){
//must be ftype, means it's an older model. par_res will be undefined
fileformat = FileFormat::NEOX_2;
@@ -182,7 +182,7 @@ void print_tok_vec(std::vector<float> &embd)
bool isNewFtype = (temp2>=1000 && temp2<=9000 && temp2%1000<20);
if(!isNewFtype)
{
fileformat = FileFormat::NEOX_2;
if((temp==0||temp==1)&&(temp2==0||temp2==1))//special case: par_res and ftype are both 1 or 0
{
@@ -193,7 +193,7 @@ void print_tok_vec(std::vector<float> &embd)
else
{
const int32_t qntvr = temp2 / 1000; //for future use
//then temp was par_res, use_parallel_residual is false in RedPajama
if(qntvr==1)
{
fileformat = (temp==0?FileFormat::NEOX_5:FileFormat::NEOX_4);
@@ -201,10 +201,10 @@ void print_tok_vec(std::vector<float> &embd)
else
{
fileformat = (temp==0?FileFormat::NEOX_7:FileFormat::NEOX_6);
}
}
}
}
}
}
else if(magic == 0x67676d66) //v2 format ggmf
@@ -244,7 +244,7 @@ void print_tok_vec(std::vector<float> &embd)
}
}
fin.close();
return fileformat;
}
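For reference, a sketch of the caller side; only FileFormat values that appear in this commit are used, and the dispatch shown is illustrative rather than the actual call site:

FileFormat file_format = check_file_format(modelname);
if (file_format == FileFormat::RWKV_1 || file_format == FileFormat::RWKV_2)
{
    //take the RWKV load path shown in the first file of this commit
}
//...GPT2_3, NEOX_2 through NEOX_7 and the ggmf/ggjt formats dispatch similarly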
@@ -296,9 +296,9 @@ void print_tok_vec(std::vector<float> &embd)
std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y)
{
int m = x.size(), n = y.size();
//int LCSuff[m+1][n+1];
std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
for (int j = 0; j <= n; j++)
LCSuff[0][j] = 0;
@@ -326,7 +326,7 @@ void print_tok_vec(std::vector<float> &embd)
auto off1 = ((i - LCSuff[i][j] + 1) - 1);
auto off2 = off1 + LCSuff[i][j];
longest.clear();
// std::vector<int>().swap(longest);
longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
// x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
}
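Despite its name, LongestCommonSubseq returns the longest common contiguous run of tokens (a longest common substring): LCSuff is the classic suffix-length table, and off1/off2 slice that run out of x. A toy example:

std::vector<int> a = {7, 1, 2, 3, 9};
std::vector<int> b = {4, 1, 2, 3, 5};
std::vector<int> shared = LongestCommonSubseq(a, b); //yields {1, 2, 3}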
@@ -336,9 +336,9 @@ void print_tok_vec(std::vector<float> &embd)
}
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
bool useSmartContext, const bool requireFullSubset)
{
const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reached to trigger smartcontext
const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context
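//worked example (illustrative): with nctx = 2048 the three thresholds come to 1638, 1228 and 1024 tokens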
@@ -349,7 +349,7 @@ void print_tok_vec(std::vector<float> &embd)
//fast forward the past based on identical tokens, stop once a divergence is noted
int embd_inp_len = embd_inp.size();
bool fastforwardok = true;
for (int i = 0; i < current_context_tokens.size(); ++i)
{
if (current_context_tokens[i] == embd_inp[i])
@@ -359,11 +359,11 @@ void print_tok_vec(std::vector<float> &embd)
}
else
{
if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
{
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
n_past = 0;
fastforwardok = false;
}
break;
}
@@ -400,10 +400,10 @@ void print_tok_vec(std::vector<float> &embd)
if (fastforwardok && useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
{
//see if smartcontext is still usable
auto shared = LongestCommonSubseq(smartcontext, embd_inp);
if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common
{
int found = ArrFindIndexOf(embd_inp,shared);
if(found>=0)
{
@@ -418,7 +418,7 @@ void print_tok_vec(std::vector<float> &embd)
{
offset_fix = 0;
}
for (int i = n_past; i < current_context_tokens.size(); ++i)
{
if (current_context_tokens[i] == embd_inp[i-offset_fix])
@@ -438,7 +438,7 @@ void print_tok_vec(std::vector<float> &embd)
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
}else{
smartcontext.clear();
}
@@ -453,16 +453,16 @@ void print_tok_vec(std::vector<float> &embd)
smartcontext.clear();
}
if(fastforwardok && useSmartContext
&& smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
&& embd_inp_len >= SCInpLenThreshold
&& current_context_tokens.size() - n_past > SCPastLenThreshold)
{
//determine longest common substring after removing start part
int shiftamt = embd_inp.size() * SCTruncationRatio;
smartcontext = std::vector<int>(embd_inp.begin() + shiftamt, embd_inp.end());
printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt);
embd_inp = smartcontext;
//if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes LCS marker.
//when a future prompt comes in, find the LCS again. If LCS > a length and LCS starts with memorized LCS
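For context, a sketch of how a caller might drive ContextFastForward with the signature above; every value here is illustrative, and is_rwkv is a hypothetical flag:

std::vector<int> current_context_tokens; //tokens already evaluated in the previous turn
std::vector<int> embd_inp; //freshly tokenized prompt for this turn
const int nctx = 2048; //context window size (illustrative)
std::vector<int> last_n_tokens(nctx, 0); //repetition-penalty window
std::vector<int> smartcontext; //persists between generations
int n_past = 0;
bool is_rwkv = true; //RWKV needs the full-subset check, per the comment in the fast-forward loop above
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens,
                   nctx, smartcontext, true /*useSmartContext*/, is_rwkv /*requireFullSubset*/);
//on return: n_past counts reusable tokens, and embd_inp holds only what still needs evaluating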