support genkeys in polled streaming
parent 6c2134a860
commit 7f112e2cd4
1 changed file with 27 additions and 6 deletions
koboldcpp.py (+27 -6)
@@ -278,8 +278,8 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar=''):
-    global maxctx, args
+def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar='', genkey=''):
+    global maxctx, args, currentusergenkey
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -329,6 +329,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
             inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+    currentusergenkey = genkey
    ret = handle.generate(inputs,outputs)
    if(ret.status==1):
        return ret.text.decode("UTF-8","ignore")
@@ -359,6 +360,7 @@ showdebug = True
 showsamplerwarning = True
 showmaxctxwarning = True
 exitcounter = 0
+currentusergenkey = "" #store a special key so polled streaming works even in multiuser
 args = None #global args
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
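
The genkey is opaque to the server: it is never parsed, only compared for equality against currentusergenkey, so any hard-to-guess string a client invents will do. A minimal sketch of how a polling client might mint one (the helper name and key format below are illustrative, not part of this change):

import secrets

def make_genkey():
    # Any unique, hard-to-guess string works; the server only tests equality.
    return "KCPP" + secrets.token_hex(8)
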
@@ -402,7 +404,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     stop_sequence=genparams.get('stop_sequence', []),
                     use_default_badwordsids=genparams.get('use_default_badwordsids', True),
                     stream_sse=stream_flag,
-                    grammar=genparams.get('grammar', ''))
+                    grammar=genparams.get('grammar', ''),
+                    genkey=genparams.get('genkey', ''))
 
             else:
                 return generate(prompt=newprompt,
@@ -424,7 +427,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     stop_sequence=genparams.get('stop_sequence', []),
                     use_default_badwordsids=genparams.get('use_default_badwordsids', True),
                     stream_sse=stream_flag,
-                    grammar=genparams.get('grammar', ''))
+                    grammar=genparams.get('grammar', ''),
+                    genkey=genparams.get('genkey', ''))
 
         recvtxt = ""
         if stream_flag:
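
With both call sites now forwarding genparams.get('genkey', ''), a client opts into key-matched polling simply by adding the field to its usual generate request. A hedged example of such a payload (every value other than genkey is a placeholder):

payload = {
    "prompt": "Once upon a time",
    "max_length": 80,
    "genkey": "KCPP4a1b2c3d",  # must match the key used later when polling /api/extra/generate/check
}
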
@@ -556,6 +560,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             lastc = handle.get_last_token_count()
             stopreason = handle.get_last_stop_reason()
             response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason, "queue":requestsinqueue, "idle":(0 if modelbusy.locked() else 1)}).encode())
 
+
+        elif self.path.endswith('/api/extra/generate/check'):
+            pendtxtStr = ""
+            if requestsinqueue==0:
+                pendtxt = handle.get_pending_output()
+                pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
+            response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
+
+
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
             response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
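
The GET form of /api/extra/generate/check has no request body, so it cannot carry a genkey; as the added handler shows, it only returns pending output while nothing is queued. A sketch of a GET poll, assuming the server is listening on the default http://localhost:5001:

import json, urllib.request

# GET poll: no genkey possible here, so text only comes back when the queue is empty.
with urllib.request.urlopen("http://localhost:5001/api/extra/generate/check") as resp:
    partial = json.loads(resp.read().decode("utf-8"))["results"][0]["text"]
print(partial)
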
@@ -573,7 +586,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             return
 
     def do_POST(self):
-        global modelbusy, requestsinqueue
+        global modelbusy, requestsinqueue, currentusergenkey
         content_length = int(self.headers['Content-Length'])
         body = self.rfile.read(content_length)
         self.path = self.path.rstrip('/')
@@ -607,7 +620,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
         if self.path.endswith('/api/extra/generate/check'):
             pendtxtStr = ""
-            if requestsinqueue==0:
+            multiuserkey = ""
+            try:
+                tempbody = json.loads(body)
+                multiuserkey = tempbody.get('genkey', "")
+            except ValueError as e:
+                multiuserkey = ""
+                pass
+
+            if (multiuserkey!="" and multiuserkey==currentusergenkey) or requestsinqueue==0:
                 pendtxt = handle.get_pending_output()
                 pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
             self.send_response(200)
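
Putting the pieces together, a client in a multiuser setup can start a generation with its own genkey and keep polling the POST form of /api/extra/generate/check with the same key, even while other requests sit in the queue. A rough end-to-end sketch, assuming the third-party requests library and a server on the default localhost:5001:

import threading, time
import requests

BASE = "http://localhost:5001"
GENKEY = "KCPP4a1b2c3d"  # same opaque token the server stores in currentusergenkey

def run_generation():
    # Blocking call; returns once the full generation is finished.
    requests.post(f"{BASE}/api/v1/generate",
                  json={"prompt": "Once upon a time", "max_length": 80, "genkey": GENKEY})

worker = threading.Thread(target=run_generation)
worker.start()

while worker.is_alive():
    # The POST body carries the genkey, so partial text comes back even when
    # other users' requests are queued (requestsinqueue != 0).
    check = requests.post(f"{BASE}/api/extra/generate/check", json={"genkey": GENKEY})
    print(check.json()["results"][0]["text"])
    time.sleep(1)
worker.join()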