API: Replace modelbusy bool with a lock.

- Also remove duplicate Content-Type header on stream responses.
Ycros 2023-07-18 20:09:50 +10:00
parent 5941514e95
commit fd90d52127
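
The core of the change is swapping a module-level busy flag for a threading.Lock, so that checking and claiming the model happens in one atomic step per request thread. A minimal sketch of the pattern the commit adopts (handle_generate and run_generation below are illustrative stand-ins, not koboldcpp's actual functions):

    import threading

    modelbusy = threading.Lock()

    def run_generation():
        # stand-in for the slow model call
        return 200

    def handle_generate():
        # acquire(blocking=False) returns False immediately if another
        # thread already holds the lock, instead of waiting for it.
        if not modelbusy.acquire(blocking=False):
            return 503  # busy: tell the client to retry later
        try:
            return run_generation()
        finally:
            # Runs on every exit path, so the server cannot stay "busy".
            modelbusy.release()

    print(handle_generate())  # 200 when idle; 503 if another thread held the lock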


@@ -273,7 +273,7 @@ friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a h
 maxctx = 2048
 maxhordectx = 1024
 maxhordelen = 256
-modelbusy = False
+modelbusy = threading.Lock()
 defaultport = 5001
 KcppVersion = "1.35"
 showdebug = True
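
For context on why the bool had to go: it was read and written from concurrent request-handler threads, which leaves a time-of-check/time-of-use window where two requests can both observe False and start generating. (The hunk adds no import, so threading is presumably already imported elsewhere in koboldcpp.py.) A standalone sketch of that race, illustrative only; the sleep just widens the window:

    import threading, time

    busy = False  # the old-style flag

    def worker(log):
        global busy
        if not busy:              # check...
            time.sleep(0.01)      # another thread can pass the same check here
            busy = True           # ...then set: check-and-set is not atomic
            log.append("started generating")
            busy = False

    log = []
    threads = [threading.Thread(target=worker, args=(log,)) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(log)  # usually ['started generating', 'started generating']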
@@ -365,7 +365,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     async def handle_sse_stream(self):
         self.send_response(200)
-        self.send_header("Content-Type", "text/event-stream")
         self.send_header("Cache-Control", "no-cache")
         self.send_header("Connection", "keep-alive")
         self.end_headers()
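
On the duplicate header: http.server's send_header() buffers one raw header line per call, so calling it twice for Content-Type puts the header on the wire twice; judging by the hunk counts, the copy inside handle_sse_stream is the one dropped. A standalone demonstration of the behavior (not koboldcpp code):

    from http.server import BaseHTTPRequestHandler, HTTPServer

    class DemoHandler(BaseHTTPRequestHandler):
        def do_GET(self):
            self.send_response(200)
            # Each send_header() call buffers another raw header line;
            # repeating a name duplicates the header in the response.
            self.send_header("Content-Type", "text/event-stream")
            self.send_header("Content-Type", "text/event-stream")
            self.end_headers()
            self.wfile.write(b"data: hello\n\n")

    # HTTPServer(("127.0.0.1", 8080), DemoHandler).serve_forever()
    # `curl -i http://127.0.0.1:8080/` then shows Content-Type twice.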
@@ -487,7 +486,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.end_headers()
             self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
             print("\nGeneration Aborted")
-            modelbusy = False
             return
 
         if self.path.endswith('/api/extra/generate/check'):
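
Note why the abort path stops touching the flag entirely: with a lock, releasing belongs to the request that acquired it (its finally block, later in this diff), and a second release from the abort handler would blow up, since releasing an unlocked threading.Lock raises. A quick standalone check:

    import threading

    lock = threading.Lock()
    lock.acquire()
    lock.release()       # fine: paired with the acquire above
    try:
        lock.release()   # double release: the lock is already unlocked
    except RuntimeError as e:
        print(e)         # -> release unlocked lock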
@@ -498,7 +496,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
             return
 
-        if modelbusy:
+        if not modelbusy.acquire(blocking=False):
             self.send_response(503)
             self.end_headers()
             self.wfile.write(json.dumps({"detail": {
@@ -507,46 +505,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 }}).encode())
             return
 
-        if self.path.endswith('/request'):
-            basic_api_flag = True
+        try:
+            if self.path.endswith('/request'):
+                basic_api_flag = True
 
-        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
-            kai_api_flag = True
+            if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
+                kai_api_flag = True
 
-        if self.path.endswith('/api/extra/generate/stream'):
-            kai_api_flag = True
-            kai_sse_stream_flag = True
+            if self.path.endswith('/api/extra/generate/stream'):
+                kai_api_flag = True
+                kai_sse_stream_flag = True
 
-        if basic_api_flag or kai_api_flag:
-            genparams = None
-            try:
-                genparams = json.loads(body)
-            except ValueError as e:
-                utfprint("Body Err: " + str(body))
-                return self.send_response(503)
+            if basic_api_flag or kai_api_flag:
+                genparams = None
+                try:
+                    genparams = json.loads(body)
+                except ValueError as e:
+                    utfprint("Body Err: " + str(body))
+                    return self.send_response(503)
 
-            if args.debugmode!=-1:
-                utfprint("\nInput: " + json.dumps(genparams))
+                if args.debugmode!=-1:
+                    utfprint("\nInput: " + json.dumps(genparams))
 
-            modelbusy = True
-
-            if kai_api_flag:
-                fullprompt = genparams.get('prompt', "")
-            else:
-                fullprompt = genparams.get('text', "")
-            newprompt = fullprompt
+                if kai_api_flag:
+                    fullprompt = genparams.get('prompt', "")
+                else:
+                    fullprompt = genparams.get('text', "")
+                newprompt = fullprompt
 
-            gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
-            try:
-                self.send_response(200)
-                self.end_headers()
-                self.wfile.write(json.dumps(gen).encode())
-            except:
-                print("Generate: The response could not be sent, maybe connection was terminated?")
+                gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+                try:
+                    self.send_response(200)
+                    self.end_headers()
+                    self.wfile.write(json.dumps(gen).encode())
+                except:
+                    print("Generate: The response could not be sent, maybe connection was terminated?")
 
-            modelbusy = False
-            return
+                return
+        finally:
+            modelbusy.release()
 
         self.send_response(404)
         self.end_headers()
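
Wrapping the whole dispatch in try/finally also fixes a failure mode the bool never handled: if handle_request raised, the old `modelbusy = False` line was never reached and the server reported busy forever. finally runs on return and exception alike; a standalone sketch under those assumptions (do_post here is a stand-in, not the real handler):

    import threading

    modelbusy = threading.Lock()

    def do_post(body):
        if not modelbusy.acquire(blocking=False):
            return 503
        try:
            if body == "boom":
                raise ValueError("generation failed")
            return 200
        finally:
            modelbusy.release()  # runs even while the ValueError propagates

    try:
        do_post("boom")
    except ValueError:
        pass
    print(modelbusy.locked())  # False: released despite the error, so the
                               # next request is not wrongly told 503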