From 9b6c35b6518b30e32f9ab6e3a5de33abd7b918cf Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Tue, 13 Jun 2023 16:02:12 +0800 Subject: [PATCH] rwkv speed enhancements (batch processing), fixed a rwkv token processing bug --- gpttype_adapter.cpp | 14 ++++++++-- model_adapter.cpp | 66 ++++++++++++++++++++++----------------------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 44997ff2f..d9f0a1ae9 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -409,7 +409,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) { - n_batch = 1; std::string word; read_rwkv_vocab(); int vocabsiz = rwkv_vocab.size(); @@ -425,6 +424,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in if (file_format == FileFormat::RWKV_1) { + n_batch = 1; rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads); //setup buffers for rwkv state @@ -453,6 +453,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } else { + n_batch = 10; //use sequence mode to speedup rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads); //setup buffers for rwkv state @@ -946,6 +947,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o if(embd_inp.size()==0 && current_context_tokens.size()>0) { embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]); + current_context_tokens.pop_back(); } } } @@ -1015,7 +1017,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } else { - evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); + if(embd.size()>1) + { + evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); + } + else + { + evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); + } + memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size()); rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out; } diff --git a/model_adapter.cpp b/model_adapter.cpp index 29d12809b..547a8a1ef 100644 --- a/model_adapter.cpp +++ b/model_adapter.cpp @@ -22,7 +22,7 @@ void timer_start() double timer_check() { auto endtime = std::chrono::high_resolution_clock().now(); - auto duration = std::chrono::duration_cast(endtime - bench_timer); + auto duration = std::chrono::duration_cast(endtime - bench_timer); double time_taken = duration.count()/1000.0; return time_taken; } @@ -37,8 +37,8 @@ void print_vec(std::vector &embd) { std::cout << ','; } - first = false; - std::cout << i; + first = false; + std::cout << i; } std::cout << "]\n"; } @@ -52,8 +52,8 @@ void print_tok_vec(std::vector &embd) { std::cout << ','; } - first = false; - std::cout << i; + first = false; + std::cout << i; } std::cout << "]\n"; } @@ -78,7 +78,7 @@ void print_tok_vec(std::vector &embd) std::cout << "]\n"; } -//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt) +//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt) FileFormat check_file_format(const std::string & fname) { std::vector f_buf(1024*1024); @@ -140,9 +140,9 @@ void print_tok_vec(std::vector &embd) fin.read((char *)&temp, sizeof(temp)); //n_layer fin.read((char *)&temp, sizeof(temp)); //f16 const int32_t qntvr = temp / 1000; - temp %= 1000; + temp %= 1000; if (qntvr != 0) - { + { if (qntvr == 1) { fileformat = FileFormat::GPT2_3; @@ -168,7 +168,7 @@ void print_tok_vec(std::vector &embd) fin.read((char *)&temp, sizeof(temp)); //n_layer fin.read((char *)&temp, sizeof(temp)); //n_rot fin.read((char *)&temp, sizeof(temp)); //either par_res or ftype (for older ver) - + if(temp!=0 && temp!=1){ //must be ftype, means its an older model. par_res will be undefined fileformat = FileFormat::NEOX_2; @@ -182,7 +182,7 @@ void print_tok_vec(std::vector &embd) bool isNewFtype = (temp2>=1000 && temp2<=9000 && temp2%1000<20); if(!isNewFtype) - { + { fileformat = FileFormat::NEOX_2; if((temp==0||temp==1)&&(temp2==0||temp2==1))//special case: par_res and ftype are both 1 or 0 { @@ -193,7 +193,7 @@ void print_tok_vec(std::vector &embd) else { const int32_t qntvr = temp2 / 1000; //for future use - //then temp was par_res, use_parallel_residual is false in RedPajama + //then temp was par_res, use_parallel_residual is false in RedPajama if(qntvr==1) { fileformat = (temp==0?FileFormat::NEOX_5:FileFormat::NEOX_4); @@ -201,10 +201,10 @@ void print_tok_vec(std::vector &embd) else { fileformat = (temp==0?FileFormat::NEOX_7:FileFormat::NEOX_6); - } + } } } - + } } else if(magic == 0x67676d66) //v2 format ggmf @@ -244,7 +244,7 @@ void print_tok_vec(std::vector &embd) } } fin.close(); - + return fileformat; } @@ -296,9 +296,9 @@ void print_tok_vec(std::vector &embd) std::vector LongestCommonSubseq(const std::vector x, const std::vector y) { int m = x.size(), n = y.size(); - + //int LCSuff[m+1][n+1]; - std::vector> LCSuff(m+1, std::vector(n+1)); + std::vector> LCSuff(m+1, std::vector(n+1)); for (int j = 0; j <= n; j++) LCSuff[0][j] = 0; @@ -326,7 +326,7 @@ void print_tok_vec(std::vector &embd) auto off1 = ((i - LCSuff[i][j] + 1) - 1); auto off2 = off1 + LCSuff[i][j]; longest.clear(); - // std::vector().swap(longest); + // std::vector().swap(longest); longest = std::vector(x.begin() + off1, x.begin() + off2); // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]); } @@ -336,9 +336,9 @@ void print_tok_vec(std::vector &embd) } void ContextFastForward(std::vector ¤t_context_tokens, std::vector &embd_inp, - int &n_past, std::vector &last_n_tokens, const int nctx, std::vector &smartcontext, + int &n_past, std::vector &last_n_tokens, const int nctx, std::vector &smartcontext, bool useSmartContext, const bool requireFullSubset) - { + { const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context @@ -349,7 +349,7 @@ void print_tok_vec(std::vector &embd) //fast forward the past based on identical tokens, stop once a divergence is noted int embd_inp_len = embd_inp.size(); bool fastforwardok = true; - + for (int i = 0; i < current_context_tokens.size(); ++i) { if (current_context_tokens[i] == embd_inp[i]) @@ -359,11 +359,11 @@ void print_tok_vec(std::vector &embd) } else { - if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context + if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context { last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end()); - n_past = 0; - fastforwardok = false; + n_past = 0; + fastforwardok = false; } break; } @@ -400,10 +400,10 @@ void print_tok_vec(std::vector &embd) if (fastforwardok && useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold) { - //see if smartcontext is still usable - auto shared = LongestCommonSubseq(smartcontext, embd_inp); + //see if smartcontext is still usable + auto shared = LongestCommonSubseq(smartcontext, embd_inp); if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common - { + { int found = ArrFindIndexOf(embd_inp,shared); if(found>=0) { @@ -418,7 +418,7 @@ void print_tok_vec(std::vector &embd) { offset_fix = 0; } - + for (int i = n_past; i < current_context_tokens.size(); ++i) { if (current_context_tokens[i] == embd_inp[i-offset_fix]) @@ -438,7 +438,7 @@ void print_tok_vec(std::vector &embd) last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past)); embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past)); - + }else{ smartcontext.clear(); } @@ -453,16 +453,16 @@ void print_tok_vec(std::vector &embd) smartcontext.clear(); } - if(fastforwardok && useSmartContext - && smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold - && embd_inp_len >= SCInpLenThreshold + if(fastforwardok && useSmartContext + && smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold + && embd_inp_len >= SCInpLenThreshold && current_context_tokens.size() - n_past > SCPastLenThreshold) - { + { //determine longest common substring after removing start part int shiftamt = embd_inp.size() * SCTruncationRatio; smartcontext = std::vector(embd_inp.begin() + shiftamt, embd_inp.end()); printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt); - + embd_inp = smartcontext; //if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes LCS marker. //when a future prompt comes in, find the LCS again. If LCS > a length and LCS starts with memorized LCS