diff --git a/koboldcpp.py b/koboldcpp.py
index 8ca30f9de..888d29f8a 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -273,7 +273,7 @@ friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a h
 maxctx = 2048
 maxhordectx = 1024
 maxhordelen = 256
-modelbusy = False
+modelbusy = threading.Lock()
 defaultport = 5001
 KcppVersion = "1.35"
 showdebug = True
@@ -365,7 +365,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     async def handle_sse_stream(self):
         self.send_response(200)
-        self.send_header("Content-Type", "text/event-stream")
         self.send_header("Cache-Control", "no-cache")
         self.send_header("Connection", "keep-alive")
         self.end_headers()
@@ -487,7 +486,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.end_headers()
             self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
             print("\nGeneration Aborted")
-            modelbusy = False
             return
 
         if self.path.endswith('/api/extra/generate/check'):
@@ -498,7 +496,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
             return
 
-        if modelbusy:
+        if not modelbusy.acquire(blocking=False):
             self.send_response(503)
             self.end_headers()
             self.wfile.write(json.dumps({"detail": {
@@ -507,46 +505,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             }}).encode())
             return
 
-        if self.path.endswith('/request'):
-            basic_api_flag = True
+        try:
+            if self.path.endswith('/request'):
+                basic_api_flag = True
 
-        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
-            kai_api_flag = True
+            if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
+                kai_api_flag = True
 
-        if self.path.endswith('/api/extra/generate/stream'):
-            kai_api_flag = True
-            kai_sse_stream_flag = True
+            if self.path.endswith('/api/extra/generate/stream'):
+                kai_api_flag = True
+                kai_sse_stream_flag = True
 
-        if basic_api_flag or kai_api_flag:
-            genparams = None
-            try:
-                genparams = json.loads(body)
-            except ValueError as e:
-                utfprint("Body Err: " + str(body))
-                return self.send_response(503)
+            if basic_api_flag or kai_api_flag:
+                genparams = None
+                try:
+                    genparams = json.loads(body)
+                except ValueError as e:
+                    utfprint("Body Err: " + str(body))
+                    return self.send_response(503)
 
-            if args.debugmode!=-1:
-                utfprint("\nInput: " + json.dumps(genparams))
+                if args.debugmode!=-1:
+                    utfprint("\nInput: " + json.dumps(genparams))
 
-            modelbusy = True
+                if kai_api_flag:
+                    fullprompt = genparams.get('prompt', "")
+                else:
+                    fullprompt = genparams.get('text', "")
+                newprompt = fullprompt
 
-            if kai_api_flag:
-                fullprompt = genparams.get('prompt', "")
-            else:
-                fullprompt = genparams.get('text', "")
-            newprompt = fullprompt
+                gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+                try:
+                    self.send_response(200)
+                    self.end_headers()
+                    self.wfile.write(json.dumps(gen).encode())
+                except:
+                    print("Generate: The response could not be sent, maybe connection was terminated?")
 
-            gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
-            try:
-                self.send_response(200)
-                self.end_headers()
-                self.wfile.write(json.dumps(gen).encode())
-            except:
-                print("Generate: The response could not be sent, maybe connection was terminated?")
-
-            modelbusy = False
-
-            return
+                return
+            finally:
+                modelbusy.release()
 
         self.send_response(404)
         self.end_headers()
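
The substance of this patch is the switch from a boolean busy flag to a threading.Lock acquired
non-blockingly, with the release moved into a finally block. The sketch below shows that pattern in
isolation, assuming a simplified handler; the names serve_generation and run_generation are
hypothetical stand-ins, not functions from koboldcpp.

    import threading, time

    modelbusy = threading.Lock()  # replaces the old `modelbusy = False` flag

    def serve_generation(payload):
        # acquire(blocking=False) returns immediately: True if this call
        # claimed the lock, False if another request holds it. The test
        # and the claim happen atomically, closing the check-then-set
        # race window a plain boolean flag has between threads.
        if not modelbusy.acquire(blocking=False):
            return {"status": 503, "detail": "Server is busy; please try again later."}
        try:
            return {"status": 200, "result": run_generation(payload)}
        finally:
            # Runs on every exit path, return or exception, so the lock
            # cannot be leaked the way a manual `modelbusy = False` reset
            # could be skipped if generation raised partway through.
            modelbusy.release()

    def run_generation(payload):
        time.sleep(0.1)  # stand-in for the actual model call
        return "generated text for: " + str(payload)

    if __name__ == "__main__":
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=2) as pool:
            print(list(pool.map(serve_generation, ["req-1", "req-2"])))
        # Expected: one request returns 200 while the concurrent one
        # gets 503, mirroring the server's one-generation-at-a-time policy.

Note the benefit over the removed code: in the old version, `gen = asyncio.run(...)` sat outside the
try block, so an exception during generation skipped `modelbusy = False` and left the server
reporting busy forever; with try/finally the release is guaranteed.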