Support genkeys in polled streaming, so a client can fetch its own pending output in multiuser mode

This commit is contained in:
Concedo 2023-09-26 23:46:07 +08:00
parent 6c2134a860
commit 7f112e2cd4

View file

@ -278,8 +278,8 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar=''):
global maxctx, args
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar='', genkey=''):
global maxctx, args, currentusergenkey
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8")
@ -329,6 +329,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
inputs.stop_sequence[n] = "".encode("UTF-8")
else:
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
currentusergenkey = genkey
ret = handle.generate(inputs,outputs)
if(ret.status==1):
return ret.text.decode("UTF-8","ignore")
@ -359,6 +360,7 @@ showdebug = True
showsamplerwarning = True
showmaxctxwarning = True
exitcounter = 0
currentusergenkey = "" #store a special key so polled streaming works even in multiuser
args = None #global args
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
@ -402,7 +404,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
stop_sequence=genparams.get('stop_sequence', []),
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''))
grammar=genparams.get('grammar', ''),
genkey=genparams.get('genkey', ''))
else:
return generate(prompt=newprompt,
@ -424,7 +427,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
stop_sequence=genparams.get('stop_sequence', []),
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''))
grammar=genparams.get('grammar', ''),
genkey=genparams.get('genkey', ''))
recvtxt = ""
if stream_flag:
@ -556,6 +560,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
lastc = handle.get_last_token_count()
stopreason = handle.get_last_stop_reason()
response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
elif self.path.endswith('/api/extra/generate/check'):
pendtxtStr = ""
if requestsinqueue==0:
pendtxt = handle.get_pending_output()
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
@ -573,7 +586,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
return
def do_POST(self):
global modelbusy, requestsinqueue
global modelbusy, requestsinqueue, currentusergenkey
content_length = int(self.headers['Content-Length'])
body = self.rfile.read(content_length)
self.path = self.path.rstrip('/')
@ -607,7 +620,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if self.path.endswith('/api/extra/generate/check'):
pendtxtStr = ""
if requestsinqueue==0:
multiuserkey = ""
try:
tempbody = json.loads(body)
multiuserkey = tempbody.get('genkey', "")
except ValueError as e:
multiuserkey = ""
pass
if (multiuserkey!="" and multiuserkey==currentusergenkey) or requestsinqueue==0:
pendtxt = handle.get_pending_output()
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
self.send_response(200)