diff --git a/koboldcpp.py b/koboldcpp.py index 531e89f43..e8c5836cb 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -278,8 +278,8 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar=''): - global maxctx, args +def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar='', genkey=''): + global maxctx, args, currentusergenkey inputs = generation_inputs() outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) inputs.prompt = prompt.encode("UTF-8") @@ -329,6 +329,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_ inputs.stop_sequence[n] = "".encode("UTF-8") else: inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8") + currentusergenkey = genkey ret = handle.generate(inputs,outputs) if(ret.status==1): return ret.text.decode("UTF-8","ignore") @@ -359,6 +360,7 @@ showdebug = True showsamplerwarning = True showmaxctxwarning = True exitcounter = 0 +currentusergenkey = "" #store a special key so polled streaming works even in multiuser args = None #global args class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): @@ -402,7 +404,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): stop_sequence=genparams.get('stop_sequence', []), use_default_badwordsids=genparams.get('use_default_badwordsids', True), stream_sse=stream_flag, - grammar=genparams.get('grammar', '')) + grammar=genparams.get('grammar', ''), + genkey=genparams.get('genkey', '')) else: return generate(prompt=newprompt, @@ -424,7 +427,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): stop_sequence=genparams.get('stop_sequence', []), use_default_badwordsids=genparams.get('use_default_badwordsids', True), stream_sse=stream_flag, - grammar=genparams.get('grammar', '')) + grammar=genparams.get('grammar', ''), + genkey=genparams.get('genkey', '')) recvtxt = "" if stream_flag: @@ -556,6 +560,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): lastc = handle.get_last_token_count() stopreason = handle.get_last_stop_reason() response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode()) + + elif self.path.endswith('/api/extra/generate/check'): + pendtxtStr = "" + if requestsinqueue==0: + pendtxt = handle.get_pending_output() + pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore") + response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode()) + + elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')): response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode()) @@ -573,7 +586,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): return def do_POST(self): - global modelbusy, requestsinqueue + global modelbusy, requestsinqueue, currentusergenkey content_length = int(self.headers['Content-Length']) body = self.rfile.read(content_length) self.path = self.path.rstrip('/') @@ -607,7 +620,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): if self.path.endswith('/api/extra/generate/check'): pendtxtStr = "" - if requestsinqueue==0: + multiuserkey = "" + try: + tempbody = json.loads(body) + multiuserkey = tempbody.get('genkey', "") + except ValueError as e: + multiuserkey = "" + pass + + if (multiuserkey!="" and multiuserkey==currentusergenkey) or requestsinqueue==0: pendtxt = handle.get_pending_output() pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore") self.send_response(200)