diff --git a/koboldcpp.py b/koboldcpp.py index 772023221..bfb534c13 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -5,7 +5,7 @@ import ctypes import os import argparse -import json, sys, time, asyncio +import json, sys, time, asyncio, socket from aiohttp import web from concurrent.futures import ThreadPoolExecutor @@ -255,8 +255,8 @@ class ServerRequestHandler: try: return res - except: - print("Generate: Error while generating") + except Exception as e: + print(f"Generate: Error while generating {e}") async def send_sse_event(self, response, event, data): @@ -273,7 +273,6 @@ class ServerRequestHandler: event_data = {"token": token} event_str = json.dumps(event_data) await self.send_sse_event(response, "message", event_str) - print(event_str) await asyncio.sleep(0) @@ -288,7 +287,6 @@ class ServerRequestHandler: generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) tasks.append(generate_task) - #tasks.append(self.generate_text(newprompt, genparams)) try: await asyncio.gather(*tasks) @@ -344,7 +342,7 @@ class ServerRequestHandler: body = await request.content.read() basic_api_flag = False kai_api_flag = False - kai_sse_stream_flag = True + kai_sse_stream_flag = False path = request.path.rstrip('/') print(request) @@ -382,10 +380,10 @@ class ServerRequestHandler: gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) - if not kai_sse_stream_flag: - return web.Response(body=gen) - modelbusy = False + + if not kai_sse_stream_flag: + return web.Response(body=json.dumps(gen).encode()) return web.Response(); return web.Response(status=404) @@ -398,6 +396,11 @@ class ServerRequestHandler: async def start_server(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.addr, self.port)) + sock.listen(5) + self.app.router.add_route('GET', '/{tail:.*}', self.handle_get) self.app.router.add_route('POST', '/{tail:.*}', self.handle_post) self.app.router.add_route('OPTIONS', '/', self.handle_options) @@ -405,7 +408,7 @@ class ServerRequestHandler: runner = web.AppRunner(self.app) await runner.setup() - site = web.TCPSite(runner, self.addr, self.port) + site = web.SockSite(runner, sock) await site.start() # Keep Alive @@ -415,7 +418,11 @@ class ServerRequestHandler: except KeyboardInterrupt: await runner.cleanup() await site.stop() - await exit(1) + await sys.exit(0) + finally: + await runner.cleanup() + await site.stop() + await sys.exit(0) async def run_server(addr, port, embedded_kailite=None): handler = ServerRequestHandler(addr, port, embedded_kailite)