fix legacy streaming

This commit is contained in:
SammCheese 2023-06-08 15:21:00 +02:00
parent 9a8da35ec4
commit b4e9e185d3
No known key found for this signature in database
GPG key ID: 28CFE2321A140BA1

View file

@ -5,7 +5,7 @@
import ctypes import ctypes
import os import os
import argparse import argparse
import json, sys, time, asyncio import json, sys, time, asyncio, socket
from aiohttp import web from aiohttp import web
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -255,8 +255,8 @@ class ServerRequestHandler:
try: try:
return res return res
except: except Exception as e:
print("Generate: Error while generating") print(f"Generate: Error while generating {e}")
async def send_sse_event(self, response, event, data): async def send_sse_event(self, response, event, data):
@ -273,7 +273,6 @@ class ServerRequestHandler:
event_data = {"token": token} event_data = {"token": token}
event_str = json.dumps(event_data) event_str = json.dumps(event_data)
await self.send_sse_event(response, "message", event_str) await self.send_sse_event(response, "message", event_str)
print(event_str)
await asyncio.sleep(0) await asyncio.sleep(0)
@ -288,7 +287,6 @@ class ServerRequestHandler:
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
tasks.append(generate_task) tasks.append(generate_task)
#tasks.append(self.generate_text(newprompt, genparams))
try: try:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@ -344,7 +342,7 @@ class ServerRequestHandler:
body = await request.content.read() body = await request.content.read()
basic_api_flag = False basic_api_flag = False
kai_api_flag = False kai_api_flag = False
kai_sse_stream_flag = True kai_sse_stream_flag = False
path = request.path.rstrip('/') path = request.path.rstrip('/')
print(request) print(request)
@ -382,10 +380,10 @@ class ServerRequestHandler:
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
if not kai_sse_stream_flag:
return web.Response(body=gen)
modelbusy = False modelbusy = False
if not kai_sse_stream_flag:
return web.Response(body=json.dumps(gen).encode())
return web.Response(); return web.Response();
return web.Response(status=404) return web.Response(status=404)
@ -398,6 +396,11 @@ class ServerRequestHandler:
async def start_server(self): async def start_server(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((self.addr, self.port))
sock.listen(5)
self.app.router.add_route('GET', '/{tail:.*}', self.handle_get) self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
self.app.router.add_route('POST', '/{tail:.*}', self.handle_post) self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
self.app.router.add_route('OPTIONS', '/', self.handle_options) self.app.router.add_route('OPTIONS', '/', self.handle_options)
@ -405,7 +408,7 @@ class ServerRequestHandler:
runner = web.AppRunner(self.app) runner = web.AppRunner(self.app)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, self.addr, self.port) site = web.SockSite(runner, sock)
await site.start() await site.start()
# Keep Alive # Keep Alive
@ -415,7 +418,11 @@ class ServerRequestHandler:
except KeyboardInterrupt: except KeyboardInterrupt:
await runner.cleanup() await runner.cleanup()
await site.stop() await site.stop()
await exit(1) await sys.exit(0)
finally:
await runner.cleanup()
await site.stop()
await sys.exit(0)
async def run_server(addr, port, embedded_kailite=None): async def run_server(addr, port, embedded_kailite=None):
handler = ServerRequestHandler(addr, port, embedded_kailite) handler = ServerRequestHandler(addr, port, embedded_kailite)