From adb4df78d6fec29623047bc619c2392245e537f2 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Fri, 14 Apr 2023 21:24:16 +0800
Subject: [PATCH] Added SmartContext mode, a way of prompt context manipulation
 that avoids frequent context recalculation.

---
 expose.h            |   1 +
 gpttype_adapter.cpp |  24 +----
 koboldcpp.py        |  11 ++-
 llama_adapter.cpp   |  29 +-----
 model_adapter.cpp   | 231 +++++++++++++++++++++++++++++++++++++++++++-
 model_adapter.h     |   9 +-
 6 files changed, 254 insertions(+), 51 deletions(-)

diff --git a/expose.h b/expose.h
index 4f255c980..66dfde66a 100644
--- a/expose.h
+++ b/expose.h
@@ -9,6 +9,7 @@ struct load_model_inputs
     const char *model_filename;
     const int n_parts_overwrite = -1;
     const bool use_mmap;
+    const bool use_smartcontext;
     const int clblast_info = 0;
 };
 struct generation_inputs
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 667f8b1d5..c025db5a5 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -35,6 +35,8 @@ static std::vector<int> current_context_tokens;
 static size_t mem_per_token = 0;
 static std::vector<float> logits;
 
+static std::vector<int> smartcontext;
+
 inline bool IsNanCheck(float f)
 {
     const unsigned int u = *(unsigned int*)&f;
@@ -194,27 +196,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    //fast forward the past based on identical tokens, stop once a divergence is noted
-    int embd_inp_len = embd_inp.size();
-    for (int i = 0; i < current_context_tokens.size(); ++i)
-    {
-        if (current_context_tokens[i] == embd_inp[i])
-        {
-            n_past += 1;
-            last_n_tokens.push_back(current_context_tokens[i]);
-        }
-        else
-        {
-            break;
-        }
-        if ((i + 2) >= embd_inp_len)
-        {
-            break;
-        }
-    }
-
-    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
-    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     // bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
diff --git a/koboldcpp.py b/koboldcpp.py
index 29d32ad9b..df9e8e6ae 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -16,6 +16,7 @@ class load_model_inputs(ctypes.Structure):
                 ("model_filename", ctypes.c_char_p),
                 ("n_parts_overwrite", ctypes.c_int),
                 ("use_mmap", ctypes.c_bool),
+                ("use_smartcontext", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int)]
 
 class generation_inputs(ctypes.Structure):
@@ -65,7 +66,7 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
 
-def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False):
+def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
     inputs.batch_size = batch_size
@@ -74,6 +75,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr
     inputs.n_parts_overwrite = n_parts_overwrite
     inputs.f16_kv = True
     inputs.use_mmap = use_mmap
+    inputs.use_smartcontext = use_smartcontext
     clblastids = 0
     if args.useclblast:
         clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
@@ -383,8 +385,8 @@ def main(args):
         mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
         modelname = os.path.abspath(ggml_selected_file)
-        print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}]")
-        loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap))
+        print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
+        loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext)
         print("Load Model OK: " + str(loadok))
 
         if not loadok:
@@ -413,7 +415,7 @@ def main(args):
     RunServerMultiThreaded(args.host, args.port, embedded_kailite)
 
 if __name__ == '__main__':
-    print("Welcome to KoboldCpp - Version 1.6") # just update version manually
+    print("Welcome to KoboldCpp - Version 1.7") # just update version manually
     parser = argparse.ArgumentParser(description='Kobold llama.cpp server')
     parser.add_argument("model_file", help="Model file to load", nargs="?")
     portgroup = parser.add_mutually_exclusive_group() #we want to be backwards compatible with the unnamed positional args
@@ -430,6 +432,7 @@ if __name__ == '__main__':
     parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
     parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
     parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
+    parser.add_argument("--smartcontext", help="Reserving a portion of context so the prompt does not need to be reprocessed as often.", action='store_true')
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
     compatgroup = parser.add_mutually_exclusive_group()
diff --git a/llama_adapter.cpp b/llama_adapter.cpp
index 6e437eb6c..20cdd4fcb 100644
--- a/llama_adapter.cpp
+++ b/llama_adapter.cpp
@@ -31,6 +31,7 @@ static std::string modelname;
 static llama_context *ctx;
 static std::vector<int> last_n_tokens;
 static std::vector<int> current_context_tokens;
+static std::vector<int> smartcontext;
 
 bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
@@ -115,9 +116,10 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_out
     }
 
     //truncate to front of the prompt if its too long
-    if (embd_inp.size() + params.n_predict > params.n_ctx)
+    int32_t nctx = params.n_ctx;
+    if (embd_inp.size() + params.n_predict > nctx)
     {
-        int offset = embd_inp.size() - params.n_ctx + params.n_predict;
+        int offset = embd_inp.size() - nctx + params.n_predict;
         embd_inp = std::vector<int>(embd_inp.begin() + offset, embd_inp.end());
     }
 
@@ -131,28 +133,7 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_out
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    //fast forward the past based on identical tokens, stop once a divergence is noted
-    int embd_inp_len = embd_inp.size();
-    int ctxcs = current_context_tokens.size();
-    for (int i = 0; i < ctxcs; ++i)
-    {
-        if (current_context_tokens[i] == embd_inp[i])
-        {
-            n_past += 1;
-            last_n_tokens.push_back(current_context_tokens[i]);
-        }
-        else
-        {
-            break;
-        }
-        if ((i + 2) >= embd_inp_len)
-        {
-            break;
-        }
-    }
-
-    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
-    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     bool blasmode = (embd_inp.size() >= 32 && ggml_cpu_has_blas());
diff --git a/model_adapter.cpp b/model_adapter.cpp
index 0a7c1fd20..5a41a6717 100644
--- a/model_adapter.cpp
+++ b/model_adapter.cpp
@@ -28,6 +28,10 @@ double timer_check()
 }
 
 void print_tok_vec(std::vector<int> &embd)
+{
+    print_tok_vec(embd,nullptr);
+}
+void print_tok_vec(std::vector<int> &embd, std::map<int, std::string> * decoder)
 {
     std::cout << "[";
     bool first = true;
@@ -38,7 +42,14 @@ void print_tok_vec(std::vector<int> &embd)
             std::cout << ',';
         }
         first = false;
-        std::cout << i;
+        if(decoder)
+        {
+            std::cout << (*decoder)[i];
+        }
+        else
+        {
+            std::cout << i;
+        }
     }
     std::cout << "]\n";
 }
@@ -125,4 +136,222 @@ void print_tok_vec(std::vector<float> &embd)
     fin.close();
 
     return fileformat;
+
 }
+
+bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+{
+    int ss = searchSeq.size();
+    if(targetArray.size()<ss)
+    {
+        return false;
+    }
+    for(int i=0;i<ss;++i)
+    {
+        if(targetArray[i]!=searchSeq[i])
+        {
+            return false;
+        }
+    }
+    return true;
+}
+
+int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+{
+    int ss = searchSeq.size();
+    int tas = targetArray.size();
+    if(tas<ss)
+    {
+        return -1;
+    }
+    for(int i=0;i<tas;++i)
+    {
+        bool fail = false;
+        for(int srch=0;srch<ss;++srch)
+        {
+            if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
+            {
+                fail = true;
+                break;
+            }
+        }
+        if(!fail)
+        {
+            return i;
+        }
+    }
+    return -1;
+}
+
+std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y)
+{
+    int m = x.size(), n = y.size();
+
+    //int LCSuff[m+1][n+1];
+    std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
+
+    for (int j = 0; j <= n; j++)
+        LCSuff[0][j] = 0;
+    for (int i = 0; i <= m; i++)
+        LCSuff[i][0] = 0;
+
+    for (int i = 1; i <= m; i++)
+    {
+        for (int j = 1; j <= n; j++)
+        {
+            if (x[i - 1] == y[j - 1])
+                LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
+            else
+                LCSuff[i][j] = 0;
+        }
+    }
+
+    std::vector<int> longest;
+    for (int i = 1; i <= m; i++)
+    {
+        for (int j = 1; j <= n; j++)
+        {
+            if (LCSuff[i][j] > longest.size())
+            {
+                auto off1 = ((i - LCSuff[i][j] + 1) - 1);
+                auto off2 = off1 + LCSuff[i][j];
+                longest.clear();
+                // std::vector<int>().swap(longest);
+                longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
+                // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
+            }
+        }
+    }
+    return longest;
+}
+
+void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
+                        int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, bool useSmartContext)
+{
+    const int SCTokThreshold = 32;              //how many tokens of similarity triggers smartcontext
+    const int SCCtxLenThreshold = nctx * 0.8;   //how much context length must be reached to trigger smartcontext
+    const int SCInpLenThreshold = nctx * 0.6;   //how big must the input array be to trigger smartcontext
+    const int SCPastLenThreshold = nctx * 0.5;  //how wide of a gap between the fast forwarded past and the present to trigger smart context
+    const float SCTruncationRatio = 0.5;        //ratio of the prompt to truncate from the front when a new smartcontext is created
+
+    // printf("\nORIGINAL CTX:\n");
+    // print_tok_vec(current_context_tokens);
+    // printf("\nORIGINAL EMBD:\n");
+    // print_tok_vec(embd_inp);
+
+    //fast forward the past based on identical tokens, stop once a divergence is noted
+    int embd_inp_len = embd_inp.size();
+    for (int i = 0; i < current_context_tokens.size(); ++i)
+    {
+        if (current_context_tokens[i] == embd_inp[i])
+        {
+            n_past += 1;
+            last_n_tokens.push_back(current_context_tokens[i]);
+        }
+        else
+        {
+            break;
+        }
+        if ((i + 2) >= embd_inp_len)
+        {
+            break;
+        }
+    }
+
+    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
+    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
+    embd_inp_len = embd_inp.size();
+
+    //smart context mode, detect if we have a shifted context at max length
+    //requirement: previous context was at least nctx/2 longer than current,
+    //mode is on, and current context already maxed.
+
+    // printf("\nconds: %d %d %d\n",current_context_tokens.size() >= nctx*0.8
+    // ,embd_inp_len >= nctx*0.6 ,current_context_tokens.size() - n_past > nctx*0.5);
+    // printf("csiz:%d par:%d eilen:%d np:%d",current_context_tokens.size(), (int)(nctx*0.8),embd_inp_len,n_past);
+
+    if (useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
+    {
+        // printf("curfullcontext:\n");
+        // print_tok_vec(current_context_tokens);
+
+        //see if smartcontext is still usable
+        // printf("smartctx:\n");
+        // print_tok_vec(smartcontext);
+        // printf("embinp:\n");
+        // print_tok_vec(embd_inp);
+        auto shared = LongestCommonSubseq(smartcontext, embd_inp);
+        if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common
+        {
+            int found = ArrFindIndexOf(embd_inp,shared);
+            if(found>=0)
+            {
+                auto trimmed = std::vector<int>(embd_inp.begin() + found, embd_inp.end());
+                embd_inp = trimmed;
+                embd_inp_len = embd_inp.size();
+                // printf("trimmed:\n");
+                // print_tok_vec(embd_inp,&vocab.id_to_token);
+                printf("\n[Reusing Smart Context: %d allowance remaining]", found);
+
+                int old_n_past = n_past;
+                int offset_fix = old_n_past;
+                if (current_context_tokens[n_past] != embd_inp[0])
+                {
+                    offset_fix = 0;
+                }
+
+                for (int i = n_past; i < current_context_tokens.size(); ++i)
+                {
+                    //printf("\n%s and %s\n",vocab.id_to_token[current_context_tokens[i]].c_str(), vocab.id_to_token[embd_inp[i-offset_fix]].c_str());
+                    if (current_context_tokens[i] == embd_inp[i-offset_fix])
+                    {
+                        n_past += 1;
+                        last_n_tokens.push_back(current_context_tokens[i]);
+                    }
+                    else
+                    {
+                        break;
+                    }
+                    if ((i + 2) >= embd_inp_len)
+                    {
+                        break;
+                    }
+                }
+
+                last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
+                embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
+                // printf("np:%d newembinp: \n",n_past);
+                // print_tok_vec(embd_inp);
+            }else{
+                smartcontext.clear();
+            }
+        }
+        else
+        {
+            smartcontext.clear();
+        }
+    }
+    else
+    {
+        smartcontext.clear();
+    }
+
+    if(useSmartContext
+    && smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
+    && embd_inp_len >= SCInpLenThreshold
+    && current_context_tokens.size() - n_past > SCPastLenThreshold)
+    {
+        //determine longest common substring after removing start part
+        int shiftamt = embd_inp.size() * SCTruncationRatio;
+        smartcontext = std::vector<int>(embd_inp.begin() + shiftamt, embd_inp.end());
+        printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt);
+        // printf("smartctx:\n");
+        // print_tok_vec(smartcontext,&vocab.id_to_token);
+        embd_inp = smartcontext;
+        //if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes the LCS marker.
+        //when a future prompt comes in, find the LCS again. If the LCS is longer than a threshold and starts with the memorized LCS,
+        //remove all tokens between the start part and the start of the LCS in the new prompt, thus avoiding the shift.
+        //if the LCS is not found or mismatched, regenerate: chop the new prompt and repeat from step B
+    }
+
+}
\ No newline at end of file
diff --git a/model_adapter.h b/model_adapter.h
index 5cfe74057..4f1036038 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -44,5 +44,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 void timer_start();
 double timer_check();
 void print_tok_vec(std::vector<int> &embd);
+void print_tok_vec(std::vector<int> &embd, std::map<int, std::string> * decoder);
 void print_tok_vec(std::vector<float> &embd);
-FileFormat check_file_format(const std::string & fname);
\ No newline at end of file
+std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y);
+bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> searchSeq);
+int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq);
+
+FileFormat check_file_format(const std::string & fname);
+void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
+                        int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, const bool useSmartContext);
\ No newline at end of file
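
Note on the refactor: the plain (non-SmartContext) path of the new ContextFastForward() is the same token-prefix fast-forward that this patch removes from gpttype_adapter.cpp and llama_adapter.cpp. A minimal standalone sketch of that shared idea follows; it is illustrative only, uses made-up token values, and is not code from this patch.

    // Sketch: reuse the longest matching prefix of the previous context as n_past,
    // so only the divergent tail of the new prompt has to be evaluated again.
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> current_context_tokens = {10, 11, 12, 13, 14, 15};
        std::vector<int> embd_inp               = {10, 11, 12, 13, 99, 100, 101};

        int n_past = 0;
        int embd_inp_len = (int)embd_inp.size();
        for (int i = 0; i < (int)current_context_tokens.size(); ++i)
        {
            if (current_context_tokens[i] == embd_inp[i])
            {
                n_past += 1;   // token already evaluated in the previous request
            }
            else
            {
                break;         // divergence: stop fast forwarding
            }
            if ((i + 2) >= embd_inp_len)
            {
                break;         // always leave at least one token to evaluate
            }
        }
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);

        printf("n_past = %d, tokens left to evaluate = %zu\n", n_past, embd_inp.size());
        // prints: n_past = 4, tokens left to evaluate = 3
        return 0;
    }

When SmartContext is in use, ContextFastForward() goes further: once the cached context is nearly full (current context >= nctx*0.8, incoming prompt >= nctx*0.6, and the portion of the old context that could not be fast forwarded exceeds nctx*0.5), it drops the front half of the prompt (shiftamt = embd_inp.size() * 0.5), memorizes the remaining half as smartcontext, and evaluates only that half. A later prompt that still contains the memorized region (checked via LongestCommonSubseq, ArrStartWith, and ArrFindIndexOf) is trimmed to start at that marker, so the usual full reprocess after a context shift is skipped until the reserved allowance is used up. For example, with nctx = 2048 the three thresholds work out to 1638, 1228, and 1024 tokens respectively.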