Merge remote-tracking branch 'ycros/api-modelbusy-fix' into concedo_experimental

This commit is contained in:
Concedo 2023-07-19 18:32:13 +08:00
commit 2a88d6d3ec

View file

@ -280,7 +280,7 @@ friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a h
maxctx = 2048 maxctx = 2048
maxhordectx = 1024 maxhordectx = 1024
maxhordelen = 256 maxhordelen = 256
modelbusy = False modelbusy = threading.Lock()
defaultport = 5001 defaultport = 5001
KcppVersion = "1.36" KcppVersion = "1.36"
showdebug = True showdebug = True
@ -373,7 +373,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
async def handle_sse_stream(self): async def handle_sse_stream(self):
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache") self.send_header("Cache-Control", "no-cache")
self.send_header("Connection", "keep-alive") self.send_header("Connection", "keep-alive")
self.end_headers() self.end_headers()
@ -495,7 +494,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode()) self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
print("\nGeneration Aborted") print("\nGeneration Aborted")
modelbusy = False
return return
if self.path.endswith('/api/extra/generate/check'): if self.path.endswith('/api/extra/generate/check'):
@ -506,7 +504,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode()) self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
return return
if modelbusy: if not modelbusy.acquire(blocking=False):
self.send_response(503) self.send_response(503)
self.end_headers() self.end_headers()
self.wfile.write(json.dumps({"detail": { self.wfile.write(json.dumps({"detail": {
@ -515,46 +513,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
}}).encode()) }}).encode())
return return
if self.path.endswith('/request'): try:
basic_api_flag = True if self.path.endswith('/request'):
basic_api_flag = True
if self.path.endswith(('/api/v1/generate', '/api/latest/generate')): if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
kai_api_flag = True kai_api_flag = True
if self.path.endswith('/api/extra/generate/stream'): if self.path.endswith('/api/extra/generate/stream'):
kai_api_flag = True kai_api_flag = True
kai_sse_stream_flag = True kai_sse_stream_flag = True
if basic_api_flag or kai_api_flag: if basic_api_flag or kai_api_flag:
genparams = None genparams = None
try: try:
genparams = json.loads(body) genparams = json.loads(body)
except ValueError as e: except ValueError as e:
utfprint("Body Err: " + str(body)) utfprint("Body Err: " + str(body))
return self.send_response(503) return self.send_response(503)
if args.debugmode!=-1: if args.debugmode!=-1:
utfprint("\nInput: " + json.dumps(genparams)) utfprint("\nInput: " + json.dumps(genparams))
modelbusy = True if kai_api_flag:
fullprompt = genparams.get('prompt', "")
else:
fullprompt = genparams.get('text', "")
newprompt = fullprompt
if kai_api_flag: gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
fullprompt = genparams.get('prompt', "") try:
else: self.send_response(200)
fullprompt = genparams.get('text', "") self.end_headers()
newprompt = fullprompt self.wfile.write(json.dumps(gen).encode())
except:
print("Generate: The response could not be sent, maybe connection was terminated?")
gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag)) return
try: finally:
self.send_response(200) modelbusy.release()
self.end_headers()
self.wfile.write(json.dumps(gen).encode())
except:
print("Generate: The response could not be sent, maybe connection was terminated?")
modelbusy = False
return
self.send_response(404) self.send_response(404)
self.end_headers() self.end_headers()