diff --git a/expose.cpp b/expose.cpp
index 8a607de5f..833b5f24d 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -24,6 +24,7 @@ extern "C" {
     {
         const int seed;
        const char * prompt;
+        const int max_context_length;
         const int max_length;
         const float temperature;
         const int top_k;
@@ -81,6 +82,7 @@ extern "C" {
         api_params.temp = inputs.temperature;
         api_params.repeat_last_n = inputs.rep_pen_range;
         api_params.repeat_penalty = inputs.rep_pen;
+        api_params.n_ctx = inputs.max_context_length;
 
         bool reset_state = inputs.reset_state;
         if(api_n_past==0)
@@ -151,7 +153,8 @@ extern "C" {
         std::mt19937 api_rng(api_params.seed);
         std::string concat_output = "";
 
-        printf("\nProcessing: ");
+        bool startedsampling = false;
+        printf("\nProcessing Prompt: ");
         while (remaining_tokens > 0)
         {
             gpt_vocab::id id = 0;
@@ -183,6 +186,12 @@ extern "C" {
             const float repeat_penalty = api_params.repeat_penalty;
             const int n_vocab = api_model.hparams.n_vocab;
 
+            if(!startedsampling)
+            {
+                startedsampling = true;
+                printf("\nGenerating: ");
+            }
+
             {
                 // set the logit of the eos token (2) to zero to avoid sampling it
                 api_logits[api_logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
diff --git a/llama_for_kobold.py b/llama_for_kobold.py
index 6e341bee3..5df75a043 100644
--- a/llama_for_kobold.py
+++ b/llama_for_kobold.py
@@ -17,6 +17,7 @@ class load_model_inputs(ctypes.Structure):
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
                 ("prompt", ctypes.c_char_p),
+                ("max_context_length", ctypes.c_int),
                 ("max_length", ctypes.c_int),
                 ("temperature", ctypes.c_float),
                 ("top_k", ctypes.c_int),
@@ -37,20 +38,21 @@ handle.load_model.restype = ctypes.c_bool
 handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
 handle.generate.restype = generation_outputs
 
-def load_model(model_filename,batch_size=8,max_context_length=2048,n_parts_overwrite=-1):
+def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
     inputs.batch_size = batch_size
-    inputs.max_context_length = max_context_length
+    inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
     inputs.threads = os.cpu_count()
     inputs.n_parts_overwrite = n_parts_overwrite
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,reset_state=True):
+def generate(prompt,max_length=20,max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,reset_state=True):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
+    inputs.max_context_length = max_context_length # this will resize the context buffer if changed
     inputs.max_length = max_length
     inputs.temperature = temperature
     inputs.top_k = top_k
@@ -125,44 +127,59 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.write(rp.encode())
         return
 
+    def do_POST(self):
+        global modelbusy
+        global last_context
         content_length = int(self.headers['Content-Length'])
-        body = self.rfile.read(content_length)
+        body = self.rfile.read(content_length)
+
+        if modelbusy:
+            self.send_response(503)
+            self.end_headers()
+            self.wfile.write(json.dumps({"detail": {
+                "msg": "Server is busy; please try again later.",
+                "type": "service_unavailable",
+            }}).encode())
+            return
+
+        basic_api_flag = False
+        kai_api_flag = False
+        if self.path.endswith('/request'):
+            basic_api_flag = True
         if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/') or self.path.endswith('/api/v1/generate') or self.path.endswith('/api/latest/generate'):
-            global modelbusy
-            global last_context
-            if modelbusy:
+            kai_api_flag = True
+
+        if basic_api_flag or kai_api_flag:
+            genparams = None
+            try:
+                genparams = json.loads(body)
+            except ValueError as e:
                 self.send_response(503)
                 self.end_headers()
-                self.wfile.write(json.dumps({"detail": {
-                    "msg": "Server is busy; please try again later.",
-                    "type": "service_unavailable",
-                }}).encode())
-                return
-            else:
-                modelbusy = True
-                genparams = None
-                try:
-                    genparams = json.loads(body)
-                except ValueError as e:
-                    self.send_response(503)
-                    self.end_headers()
-                    return
-
-                print("\nInput: " + json.dumps(genparams))
-                fresh_state = True
+                return
+            print("\nInput: " + json.dumps(genparams))
+            fresh_state = True
+            modelbusy = True
+            if kai_api_flag:
                 fullprompt = genparams.get('prompt', "")
-                newprompt = fullprompt
-                if last_context!="" and newprompt.startswith(last_context):
-                    fresh_state = False
-                    newprompt = newprompt[len(last_context):]
-                    print("Resuming state, new input len: " + str(len(newprompt)))
-                    #print("trimmed: " + newprompt)
+            else:
+                fullprompt = genparams.get('text', "")
+            newprompt = fullprompt
+            if last_context!="" and newprompt.startswith(last_context):
+                fresh_state = False
+                newprompt = newprompt[len(last_context):]
+                print("Resuming state, new input len: " + str(len(newprompt)))
+
+
+            recvtxt = ""
+            if kai_api_flag:
                 recvtxt = generate(
                     prompt=newprompt,
+                    max_context_length=genparams.get('max_context_length', maxctx),
                     max_length=genparams.get('max_length', 50),
                     temperature=genparams.get('temperature', 0.8),
-                    top_k=genparams.get('top_k', 100),
+                    top_k=genparams.get('top_k', 200),
                     top_p=genparams.get('top_p', 0.85),
                     rep_pen=genparams.get('rep_pen', 1.1),
                     rep_pen_range=genparams.get('rep_pen_range', 128),
@@ -174,9 +191,28 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 res = {"results": [{"text": recvtxt}]}
                 self.send_response(200)
                 self.end_headers()
+                self.wfile.write(json.dumps(res).encode())
+            else:
+                recvtxt = generate(
+                    prompt=newprompt,
+                    max_length=genparams.get('max', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 200),
+                    top_p=genparams.get('top_p', 0.85),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=-1,
+                    reset_state=fresh_state
+                )
+                print("\nOutput: " + recvtxt)
+                last_context = fullprompt + recvtxt
+                res = {"data": {"seqs":[recvtxt]}}
+                self.send_response(200)
+                self.end_headers()
                 self.wfile.write(json.dumps(res).encode())
-                modelbusy = False
-                return
+            modelbusy = False
+            return
+
         self.send_response(404)
         self.end_headers()
@@ -265,7 +301,7 @@ if __name__ == '__main__':
         mdl_nparts += 1
     modelname = os.path.abspath(sys.argv[1])
     print("Loading model: " + modelname)
-    loadok = load_model(modelname,24,maxctx,mdl_nparts)
+    loadok = load_model(modelname,16,maxctx,mdl_nparts)
     print("Load Model OK: " + str(loadok))
     #friendlymodelname = Path(modelname).stem ### this wont work on local kobold api, so we must hardcode a known HF model name
diff --git a/llamacpp.dll b/llamacpp.dll
index 8463cc428..93ffe71a2 100644
Binary files a/llamacpp.dll and b/llamacpp.dll differ
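
For reference, below is a minimal client-side sketch (not part of the patch) that exercises the two routes handled by the reworked do_POST above: the KoboldAI-compatible /api/v1/generate endpoint, which now accepts the new max_context_length parameter, and the basic /request endpoint, which reads 'text'/'max' and answers with {"data": {"seqs": [...]}}. The host and port (localhost:5001) are assumptions for illustration only; this diff does not define them.

# Minimal client sketch. Assumption: the llama_for_kobold.py server is reachable at
# localhost:5001 (the port is not set anywhere in this diff). Only parameters that
# do_POST actually reads are sent.
import json
import urllib.request

def post_json(url, payload):
    # POST a JSON body and decode the JSON response, mirroring what do_POST expects.
    req = urllib.request.Request(url,
                                 data=json.dumps(payload).encode("UTF-8"),
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

base = "http://localhost:5001"  # assumed address, adjust to your server

# KoboldAI-compatible endpoint: response shape is {"results": [{"text": ...}]}
kai = post_json(base + "/api/v1/generate", {
    "prompt": "Once upon a time",
    "max_context_length": 512,  # new field from this patch; resizes the ctx buffer
    "max_length": 50,
    "top_k": 200,
})
print(kai["results"][0]["text"])

# Basic endpoint: reads 'text' and 'max'; response shape is {"data": {"seqs": [...]}}
basic = post_json(base + "/request", {"text": "Once upon a time", "max": 50})
print(basic["data"]["seqs"][0])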