diff --git a/expose.cpp b/expose.cpp
index b4df04b8c..0c0c33616 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -10,6 +10,21 @@
 #include "main.cpp"
 #include "extra.h"
 
+void print_tok_vec(std::vector<llama_token> & embd)
+{
+    std::cout << "[";
+    bool first = true;
+    for (auto i: embd) {
+        if(!first)
+        {
+            std::cout << ',';
+        }
+        first = false;
+        std::cout << i;
+    }
+    std::cout << "]";
+}
+
 extern "C" {
 
     struct load_model_inputs
@@ -31,7 +46,6 @@ extern "C" {
         const float top_p;
        const float rep_pen;
        const int rep_pen_range;
-       const bool reset_state = true; //determines if we can continue off the previous prompt state
    };
    struct generation_outputs
    {
@@ -43,12 +57,12 @@ extern "C" {
    llama_context_params ctx_params;
    gpt_params params;
    int n_past = 0;
-   llama_token old_embd_id = -1;
    int n_threads = 4;
    int n_batch = 8;
    std::string model;
    llama_context * ctx;
    std::vector<llama_token> last_n_tokens;
+   std::vector<llama_token> current_context_tokens;

    bool load_model(const load_model_inputs inputs)
    {
@@ -80,6 +94,10 @@
            printf("\n---\nWarning: Your model is using an OUTDATED format. Please reconvert it for better results!\n");
        }

+       //determine mem per token
+       const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
+       llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+
        return true;
    }
@@ -96,12 +114,6 @@
        params.n_ctx = inputs.max_context_length;
        params.n_batch = n_batch;
        params.n_threads = n_threads;
-
-       bool reset_state = inputs.reset_state;
-       if(n_past==0)
-       {
-           reset_state = true;
-       }

        if(params.repeat_last_n<1)
        {
@@ -115,12 +127,9 @@
        {
            params.seed = time(NULL);
        }
-
-       if(reset_state)
-       {
-           params.prompt.insert(0, 1, ' ');
-       }
-
+
+       params.prompt.insert(0, 1, ' ');
+
        // tokenize the prompt
        std::vector<llama_token> embd_inp;
        if(legacy_format)
        {
@@ -135,7 +144,10 @@
        if (embd_inp.size() + params.n_predict > params.n_ctx) {
            int offset = embd_inp.size() - params.n_ctx + params.n_predict;
            embd_inp = std::vector<llama_token>(embd_inp.begin() + offset, embd_inp.end());
-       }
+       }
+
+       //determine how much npast we have to rewind from the current state
+
        std::vector<llama_token> embd;

        int last_n_size = params.repeat_last_n;
@@ -145,26 +157,30 @@
        // std::string tst = " ";
        // char * tst2 = (char*)tst.c_str();
        // gpt_print_usage(1,&tst2,params);
+
+       std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+       n_past = 0;
-       if(reset_state)
+       //fast forward the past based on identical tokens, stop once a divergence is noted
+       for(int i=0;i<current_context_tokens.size();++i)
        {
-           //determine mem per token
-           const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-           llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-           std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-           n_past = 0;
-       }
-       else
-       {
-           //strip out the reset token (1) at the start of the embedding
-           if(embd_inp.size()>0)
+           if(current_context_tokens[i]==embd_inp[0])
            {
+               n_past += 1;
                embd_inp.erase(embd_inp.begin());
+               last_n_tokens.erase(last_n_tokens.begin());
+               last_n_tokens.push_back(current_context_tokens[i]);
            }
-           if(old_embd_id!=-1)
+           else
            {
-               embd.push_back(old_embd_id);
+               break;
+           }
+           if(embd_inp.size()<=1)
+           {
+               break;
            }
        }
+       current_context_tokens.resize(n_past);

        int remaining_tokens = params.n_predict;
        int input_consumed = 0;
@@ -180,11 +196,8 @@
            // predict
            if (embd.size() > 0)
            {
-               printf("|");
-               // for (auto i: embd) {
-               //     std::cout << i << ',';
-               // }
-               // printf("\nnp:%d embd:%d",n_past,embd.size());
+               printf("|");
+               //printf("\nnp:%d embd:%d txt:%s",n_past,embd.size(),llama_token_to_str(ctx, embd[0]));
                if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads))
                {
                    fprintf(stderr, "Failed to predict\n");
@@ -222,13 +235,12 @@ extern "C" {

                last_n_tokens.erase(last_n_tokens.begin());
                last_n_tokens.push_back(id);
+               current_context_tokens.push_back(id);
            }

            // add it to the context
-           old_embd_id = id;
            embd.push_back(id);
-
            // decrement remaining sampling budget
            --remaining_tokens;

            //printf("\nid:%d word:%s\n",id,llama_token_to_str(ctx, id));
@@ -239,10 +251,10 @@ extern "C" {
            // some user input remains from prompt or interaction, forward it to processing
            while ((int) embd_inp.size() > input_consumed)
            {
-               old_embd_id = embd_inp[input_consumed];
                embd.push_back(embd_inp[input_consumed]);
                last_n_tokens.erase(last_n_tokens.begin());
                last_n_tokens.push_back(embd_inp[input_consumed]);
+               current_context_tokens.push_back(embd_inp[input_consumed]);
                ++input_consumed;
                if ((int) embd.size() >= params.n_batch)
                {
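The heart of the C++ change is the fast-forward loop in generate(): rather than having the client decide via reset_state whether the previous state is reusable, the server now remembers current_context_tokens, walks it against the freshly tokenized prompt, advances n_past across the longest shared prefix, and re-evaluates only the divergent tail. A minimal sketch of that mechanism (the name fast_forward is illustrative, not part of the patch, and the last_n_tokens bookkeeping is omitted):

    def fast_forward(cached, prompt):
        # cached: tokens already evaluated by the model (current_context_tokens)
        # prompt: freshly tokenized input (embd_inp), consumed in place
        # returns n_past, the number of leading tokens needing no re-evaluation
        n_past = 0
        for tok in list(cached):
            # stop at the first divergence, always keeping at least one
            # prompt token so the model has something to evaluate
            if len(prompt) <= 1 or tok != prompt[0]:
                break
            prompt.pop(0)
            n_past += 1
        del cached[n_past:]  # drop the stale tail beyond the match
        return n_past

    cached = [10, 11, 12, 13]
    prompt = [10, 11, 12, 99, 100]
    assert fast_forward(cached, prompt) == 3
    assert prompt == [99, 100] and cached == [10, 11, 12]

Here three of the four cached tokens match, so llama_eval would be called with n_past = 3 and only the two new tokens, instead of re-processing the entire prompt from scratch.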
"C" { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); + current_context_tokens.push_back(id); } // add it to the context - old_embd_id = id; embd.push_back(id); - // decrement remaining sampling budget --remaining_tokens; //printf("\nid:%d word:%s\n",id,llama_token_to_str(ctx, id)); @@ -239,10 +251,10 @@ extern "C" { // some user input remains from prompt or interaction, forward it to processing while ((int) embd_inp.size() > input_consumed) { - old_embd_id = embd_inp[input_consumed]; embd.push_back(embd_inp[input_consumed]); last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(embd_inp[input_consumed]); + current_context_tokens.push_back(embd_inp[input_consumed]); ++input_consumed; if ((int) embd.size() >= params.n_batch) { diff --git a/llama_for_kobold.py b/llama_for_kobold.py index 5df75a043..5b4906f6b 100644 --- a/llama_for_kobold.py +++ b/llama_for_kobold.py @@ -23,8 +23,7 @@ class generation_inputs(ctypes.Structure): ("top_k", ctypes.c_int), ("top_p", ctypes.c_float), ("rep_pen", ctypes.c_float), - ("rep_pen_range", ctypes.c_int), - ("reset_state", ctypes.c_bool)] + ("rep_pen_range", ctypes.c_int)] class generation_outputs(ctypes.Structure): _fields_ = [("status", ctypes.c_int), @@ -48,7 +47,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr ret = handle.load_model(inputs) return ret -def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,reset_state=True): +def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1): inputs = generation_inputs() outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) inputs.prompt = prompt.encode("UTF-8") @@ -60,7 +59,6 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k= inputs.rep_pen = rep_pen inputs.rep_pen_range = rep_pen_range inputs.seed = seed - inputs.reset_state = reset_state ret = handle.generate(inputs,outputs) if(ret.status==1): return ret.text.decode("UTF-8") @@ -80,7 +78,6 @@ maxctx = 2048 maxlen = 128 modelbusy = False port = 5001 -last_context = "" embedded_kailite = None class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): @@ -130,7 +127,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): def do_POST(self): global modelbusy - global last_context content_length = int(self.headers['Content-Length']) body = self.rfile.read(content_length) @@ -159,18 +155,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): self.end_headers() return print("\nInput: " + json.dumps(genparams)) - fresh_state = True + modelbusy = True if kai_api_flag: fullprompt = genparams.get('prompt', "") else: fullprompt = genparams.get('text', "") newprompt = fullprompt - if last_context!="" and newprompt.startswith(last_context): - fresh_state = False - newprompt = newprompt[len(last_context):] - print("Resuming state, new input len: " + str(len(newprompt))) - + recvtxt = "" if kai_api_flag: @@ -183,11 +175,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): top_p=genparams.get('top_p', 0.85), rep_pen=genparams.get('rep_pen', 1.1), rep_pen_range=genparams.get('rep_pen_range', 128), - seed=-1, - reset_state=fresh_state + seed=-1 ) print("\nOutput: " + recvtxt) - last_context = fullprompt + recvtxt res = {"results": [{"text": recvtxt}]} self.send_response(200) self.end_headers() @@ -201,11 +191,9 @@ class 
diff --git a/llamacpp.dll b/llamacpp.dll
index c487cba2a..de83ebb1f 100644
Binary files a/llamacpp.dll and b/llamacpp.dll differ
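With the rebuilt llamacpp.dll, resumption becomes transparent to clients: the HTTP layer no longer tracks last_context, and a caller simply resends its growing transcript. A hypothetical client flow (it assumes the helpers above can be imported as a module without starting the server, and a real model file; both are assumptions, not something the patch shows):

    import llama_for_kobold as lk  # hypothetical import of the script above

    lk.load_model("model.bin", max_context_length=512)  # illustrative path

    story = "Once upon a time"
    story += lk.generate(story, max_length=32)

    # Resend the full, longer transcript: no reset_state flag is needed,
    # because the C++ side fast-forwards over the shared token prefix and
    # only evaluates the newly appended text.
    story += lk.generate(story, max_length=32)
    print(story)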