diff --git a/expose.cpp b/expose.cpp
index 2df992bac..e10ec876b 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -28,7 +28,8 @@ extern "C" {
         const int top_k;
         const float top_p;
         const float rep_pen;
-        const int rep_pen_range;
+        const int rep_pen_range;
+        const bool reset_state = true; //determines if we can continue off the previous prompt state
     };
     struct generation_outputs
     {
@@ -40,7 +41,10 @@ extern "C" {
     gpt_vocab api_vocab;
     llama_model api_model;
     int api_n_past = 0;
+    gpt_vocab::id old_embd_id = -1;
     std::vector<float> api_logits;
+    std::vector<gpt_vocab::id> last_n_tokens;
+    size_t mem_per_token = 0;
 
     bool load_model(const load_model_inputs inputs)
     {
@@ -69,6 +73,12 @@ extern "C" {
         api_params.temp = inputs.temperature;
         api_params.repeat_last_n = inputs.rep_pen_range;
         api_params.repeat_penalty = inputs.rep_pen;
+
+        bool reset_state = inputs.reset_state;
+        if(api_n_past==0)
+        {
+            reset_state = true;
+        }
 
         if(api_params.repeat_last_n<1)
         {
@@ -88,42 +98,61 @@ extern "C" {
         // char * tst2 = (char*)tst.c_str();
         // gpt_print_usage(1,&tst2,api_params);
 
-        api_params.prompt.insert(0, 1, ' ');
+        if(reset_state)
+        {
+            api_params.prompt.insert(0, 1, ' ');
+            mem_per_token = 0;
+        }
         // tokenize the prompt
         std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(api_vocab, api_params.prompt, true);
         api_params.n_predict = std::min(api_params.n_predict, api_model.hparams.n_ctx - (int)embd_inp.size());
         std::vector<gpt_vocab::id> embd;
-        size_t mem_per_token = 0;
-        llama_eval(api_model, api_params.n_threads, 0, {0, 1, 2, 3}, api_logits, mem_per_token);
-
+
         int last_n_size = api_params.repeat_last_n;
-        std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
-        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+        last_n_tokens.resize(last_n_size);
+        if(reset_state)
+        {
+            llama_eval(api_model, api_params.n_threads, 0, {0, 1, 2, 3}, api_logits, mem_per_token);
+            std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+            api_n_past = 0;
+        }else{
+            //strip out the reset token (1) at the start of the embedding
+            if(embd_inp.size()>0)
+            {
+                embd_inp.erase(embd_inp.begin());
+            }
+            if(old_embd_id!=-1)
+            {
+                embd.push_back(old_embd_id);
+            }
+        }
+
         int remaining_tokens = api_params.n_predict;
         int input_consumed = 0;
         std::mt19937 api_rng(api_params.seed);
-
-        std::string concat_output = "";
+        std::string concat_output = "";
 
         while (remaining_tokens > 0)
        {
-            gpt_vocab::id id = 0;
+            gpt_vocab::id id = 0;
             // predict
             if (embd.size() > 0)
             {
-
+                // for (auto i: embd) {
+                //     std::cout << i << ',';
+                // }
+                //printf("\nnp:%d embd:%d mem:%d",api_n_past,embd.size(),mem_per_token);
                 if (!llama_eval(api_model, api_params.n_threads, api_n_past, embd, api_logits, mem_per_token))
                 {
                     fprintf(stderr, "Failed to predict\n");
-                    _snprintf_s(output.text,sizeof(output.text),_TRUNCATE,"%s","");
+                    snprintf(output.text, sizeof(output.text), "%s", "");
                     output.status = 0;
                     return output;
                 }
             }
 
             api_n_past += embd.size();
-            embd.clear();
-
+            embd.clear();
             if (embd_inp.size() <= input_consumed)
             {
                 // out of user input, sample next token
@@ -148,11 +177,12 @@ extern "C" {
                 }
 
                 // add it to the context
+                old_embd_id = id;
                 embd.push_back(id);
 
                 // decrement remaining sampling budget
                 --remaining_tokens;
-
+                //printf("\nid:%d word:%s\n",id,api_vocab.id_to_token[id].c_str());
                 concat_output += api_vocab.id_to_token[id].c_str();
             }
             else
@@ -160,6 +190,7 @@ extern "C" {
                 // some user input remains from prompt or interaction, forward it to processing
                 while (embd_inp.size() > input_consumed)
                 {
+                    old_embd_id = embd_inp[input_consumed];
                     embd.push_back(embd_inp[input_consumed]);
                     last_n_tokens.erase(last_n_tokens.begin());
                     last_n_tokens.push_back(embd_inp[input_consumed]);
@@ -175,7 +206,7 @@ extern "C" {
 
         //printf("output: %s",concat_output.c_str());
         output.status = 1;
-        _snprintf_s(output.text,sizeof(output.text),_TRUNCATE,"%s",concat_output.c_str());
+        snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
         return output;
     }
 }
\ No newline at end of file
diff --git a/llama_for_kobold.py b/llama_for_kobold.py
index be02d5022..7b2cc8f0a 100644
--- a/llama_for_kobold.py
+++ b/llama_for_kobold.py
@@ -21,7 +21,8 @@ class generation_inputs(ctypes.Structure):
                 ("top_k", ctypes.c_int),
                 ("top_p", ctypes.c_float),
                 ("rep_pen", ctypes.c_float),
-                ("rep_pen_range", ctypes.c_int)]
+                ("rep_pen_range", ctypes.c_int),
+                ("reset_state", ctypes.c_bool)]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -45,7 +46,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,threads=4,n_pa
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1):
+def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,reset_state=True):
     inputs = generation_inputs()
     outputs = generation_outputs()
     inputs.prompt = prompt.encode("UTF-8")
@@ -56,6 +57,7 @@ def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
     inputs.seed = seed
+    inputs.reset_state = reset_state
     ret = handle.generate(inputs,outputs)
     if(ret.status==1):
         return ret.text.decode("UTF-8")
@@ -75,6 +77,7 @@ maxctx = 1024
 maxlen = 256
 modelbusy = False
 port = 5001
+last_context = ""
 
 class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
 
@@ -120,7 +123,8 @@ class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
         content_length = int(self.headers['Content-Length'])
         body = self.rfile.read(content_length)
         if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/'):
-            global modelbusy
+            global modelbusy
+            global last_context
             if modelbusy:
                 self.send_response(503)
                 self.end_headers()
@@ -140,17 +144,26 @@ class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
                 return
 
             print("\nInput: " + json.dumps(genparams))
+            fresh_state = True
+            fullprompt = genparams.get('prompt', "")
+            newprompt = fullprompt
+            if last_context!="" and newprompt.startswith(last_context):
+                fresh_state = False
+                newprompt = newprompt[len(last_context):]
+                #print("trimmed: " + newprompt)
             recvtxt = generate(
-                prompt=genparams.get('prompt', ""),
+                prompt=newprompt,
                 max_length=genparams.get('max_length', 50),
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 100),
                 top_p=genparams.get('top_p', 0.85),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
-                seed=-1
+                seed=-1,
+                reset_state=fresh_state
                 )
             print("\nOutput: " + recvtxt)
+            last_context = fullprompt + recvtxt
             res = {"results": [{"text": recvtxt}]}
             self.send_response(200)
             self.end_headers()
@@ -241,7 +254,7 @@ if __name__ == '__main__':
         mdl_nparts += 1
     modelname = os.path.abspath(sys.argv[1])
     print("Loading model: " + modelname)
-    loadok = load_model(modelname,128,maxctx,4,mdl_nparts)
+    loadok = load_model(modelname,24,maxctx,4,mdl_nparts)
     print("Load Model OK: " + str(loadok))
 
     if loadok:
diff --git a/llamacpp.dll b/llamacpp.dll
index baac468e7..4d3bb22dd 100644
Binary files a/llamacpp.dll and b/llamacpp.dll differ
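
Reviewer note: a minimal standalone sketch of the prompt-prefix reuse decision that llama_for_kobold.py now makes before calling generate(). The helper name split_reusable_prefix is illustrative only (it does not exist in the patch); the logic mirrors the last_context / newprompt / fresh_state variables introduced above.

# Sketch, assuming the server keeps the previous context+output in last_context.
def split_reusable_prefix(last_context, fullprompt):
    # If the incoming prompt extends the previously generated context,
    # only the new suffix needs to be evaluated and the cached llama
    # state can be kept (reset_state=False).
    if last_context != "" and fullprompt.startswith(last_context):
        return fullprompt[len(last_context):], False
    # Otherwise the prompt diverged, so evaluation starts from scratch.
    return fullprompt, True

# Example: a Kobold client resubmits the whole story plus one new sentence.
prompt, fresh = split_reusable_prefix("Once upon a time", "Once upon a time there was a llama.")
# prompt == " there was a llama.", fresh == False
# -> generate(prompt, ..., reset_state=fresh)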