rwkv speed enhancements (batch processing), fixed a rwkv token processing bug
This commit is contained in:
parent
860fb026df
commit
9b6c35b651
2 changed files with 45 additions and 35 deletions
|
@ -409,7 +409,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
n_batch = 1;
|
||||
std::string word;
|
||||
read_rwkv_vocab();
|
||||
int vocabsiz = rwkv_vocab.size();
|
||||
|
@ -425,6 +424,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
n_batch = 1;
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
|
@ -453,6 +453,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
else
|
||||
{
|
||||
n_batch = 10; //use sequence mode to speedup
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
|
@ -946,6 +947,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
||||
{
|
||||
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||
current_context_tokens.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1015,7 +1017,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
else
|
||||
{
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
if(embd.size()>1)
|
||||
{
|
||||
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
|
||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ void timer_start()
|
|||
double timer_check()
|
||||
{
|
||||
auto endtime = std::chrono::high_resolution_clock().now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endtime - bench_timer);
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endtime - bench_timer);
|
||||
double time_taken = duration.count()/1000.0;
|
||||
return time_taken;
|
||||
}
|
||||
|
@ -37,8 +37,8 @@ void print_vec(std::vector<std::string> &embd)
|
|||
{
|
||||
std::cout << ',';
|
||||
}
|
||||
first = false;
|
||||
std::cout << i;
|
||||
first = false;
|
||||
std::cout << i;
|
||||
}
|
||||
std::cout << "]\n";
|
||||
}
|
||||
|
@ -52,8 +52,8 @@ void print_tok_vec(std::vector<int> &embd)
|
|||
{
|
||||
std::cout << ',';
|
||||
}
|
||||
first = false;
|
||||
std::cout << i;
|
||||
first = false;
|
||||
std::cout << i;
|
||||
}
|
||||
std::cout << "]\n";
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
std::cout << "]\n";
|
||||
}
|
||||
|
||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||
FileFormat check_file_format(const std::string & fname)
|
||||
{
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
|
@ -140,9 +140,9 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
fin.read((char *)&temp, sizeof(temp)); //n_layer
|
||||
fin.read((char *)&temp, sizeof(temp)); //f16
|
||||
const int32_t qntvr = temp / 1000;
|
||||
temp %= 1000;
|
||||
temp %= 1000;
|
||||
if (qntvr != 0)
|
||||
{
|
||||
{
|
||||
if (qntvr == 1)
|
||||
{
|
||||
fileformat = FileFormat::GPT2_3;
|
||||
|
@ -168,7 +168,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
fin.read((char *)&temp, sizeof(temp)); //n_layer
|
||||
fin.read((char *)&temp, sizeof(temp)); //n_rot
|
||||
fin.read((char *)&temp, sizeof(temp)); //either par_res or ftype (for older ver)
|
||||
|
||||
|
||||
if(temp!=0 && temp!=1){
|
||||
//must be ftype, means its an older model. par_res will be undefined
|
||||
fileformat = FileFormat::NEOX_2;
|
||||
|
@ -182,7 +182,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
bool isNewFtype = (temp2>=1000 && temp2<=9000 && temp2%1000<20);
|
||||
|
||||
if(!isNewFtype)
|
||||
{
|
||||
{
|
||||
fileformat = FileFormat::NEOX_2;
|
||||
if((temp==0||temp==1)&&(temp2==0||temp2==1))//special case: par_res and ftype are both 1 or 0
|
||||
{
|
||||
|
@ -193,7 +193,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
else
|
||||
{
|
||||
const int32_t qntvr = temp2 / 1000; //for future use
|
||||
//then temp was par_res, use_parallel_residual is false in RedPajama
|
||||
//then temp was par_res, use_parallel_residual is false in RedPajama
|
||||
if(qntvr==1)
|
||||
{
|
||||
fileformat = (temp==0?FileFormat::NEOX_5:FileFormat::NEOX_4);
|
||||
|
@ -201,10 +201,10 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
else
|
||||
{
|
||||
fileformat = (temp==0?FileFormat::NEOX_7:FileFormat::NEOX_6);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if(magic == 0x67676d66) //v2 format ggmf
|
||||
|
@ -244,7 +244,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
}
|
||||
}
|
||||
fin.close();
|
||||
|
||||
|
||||
return fileformat;
|
||||
}
|
||||
|
||||
|
@ -296,9 +296,9 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y)
|
||||
{
|
||||
int m = x.size(), n = y.size();
|
||||
|
||||
|
||||
//int LCSuff[m+1][n+1];
|
||||
std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
|
||||
std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
|
||||
|
||||
for (int j = 0; j <= n; j++)
|
||||
LCSuff[0][j] = 0;
|
||||
|
@ -326,7 +326,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
auto off1 = ((i - LCSuff[i][j] + 1) - 1);
|
||||
auto off2 = off1 + LCSuff[i][j];
|
||||
longest.clear();
|
||||
// std::vector<int>().swap(longest);
|
||||
// std::vector<int>().swap(longest);
|
||||
longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
|
||||
// x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
|
||||
}
|
||||
|
@ -336,9 +336,9 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
}
|
||||
|
||||
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
|
||||
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
|
||||
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
|
||||
bool useSmartContext, const bool requireFullSubset)
|
||||
{
|
||||
{
|
||||
const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
|
||||
const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
|
||||
const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context
|
||||
|
@ -349,7 +349,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
//fast forward the past based on identical tokens, stop once a divergence is noted
|
||||
int embd_inp_len = embd_inp.size();
|
||||
bool fastforwardok = true;
|
||||
|
||||
|
||||
for (int i = 0; i < current_context_tokens.size(); ++i)
|
||||
{
|
||||
if (current_context_tokens[i] == embd_inp[i])
|
||||
|
@ -359,11 +359,11 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
}
|
||||
else
|
||||
{
|
||||
if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
|
||||
if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
|
||||
{
|
||||
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
|
||||
n_past = 0;
|
||||
fastforwardok = false;
|
||||
n_past = 0;
|
||||
fastforwardok = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -400,10 +400,10 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
|
||||
if (fastforwardok && useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
|
||||
{
|
||||
//see if smartcontext is still usable
|
||||
auto shared = LongestCommonSubseq(smartcontext, embd_inp);
|
||||
//see if smartcontext is still usable
|
||||
auto shared = LongestCommonSubseq(smartcontext, embd_inp);
|
||||
if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common
|
||||
{
|
||||
{
|
||||
int found = ArrFindIndexOf(embd_inp,shared);
|
||||
if(found>=0)
|
||||
{
|
||||
|
@ -418,7 +418,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
{
|
||||
offset_fix = 0;
|
||||
}
|
||||
|
||||
|
||||
for (int i = n_past; i < current_context_tokens.size(); ++i)
|
||||
{
|
||||
if (current_context_tokens[i] == embd_inp[i-offset_fix])
|
||||
|
@ -438,7 +438,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
|
||||
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
|
||||
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
|
||||
|
||||
|
||||
}else{
|
||||
smartcontext.clear();
|
||||
}
|
||||
|
@ -453,16 +453,16 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
smartcontext.clear();
|
||||
}
|
||||
|
||||
if(fastforwardok && useSmartContext
|
||||
&& smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
|
||||
&& embd_inp_len >= SCInpLenThreshold
|
||||
if(fastforwardok && useSmartContext
|
||||
&& smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
|
||||
&& embd_inp_len >= SCInpLenThreshold
|
||||
&& current_context_tokens.size() - n_past > SCPastLenThreshold)
|
||||
{
|
||||
{
|
||||
//determine longest common substring after removing start part
|
||||
int shiftamt = embd_inp.size() * SCTruncationRatio;
|
||||
smartcontext = std::vector<int>(embd_inp.begin() + shiftamt, embd_inp.end());
|
||||
printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt);
|
||||
|
||||
|
||||
embd_inp = smartcontext;
|
||||
//if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes LCS marker.
|
||||
//when a future prompt comes in, find the LCS again. If LCS > a length and LCS starts with memorized LCS
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue