fix legacy streaming

This commit is contained in:
SammCheese 2023-06-08 15:21:00 +02:00
parent 9a8da35ec4
commit b4e9e185d3
No known key found for this signature in database
GPG key ID: 28CFE2321A140BA1

View file

@ -5,7 +5,7 @@
import ctypes
import os
import argparse
import json, sys, time, asyncio
import json, sys, time, asyncio, socket
from aiohttp import web
from concurrent.futures import ThreadPoolExecutor
@ -255,8 +255,8 @@ class ServerRequestHandler:
try:
return res
except:
print("Generate: Error while generating")
except Exception as e:
print(f"Generate: Error while generating {e}")
async def send_sse_event(self, response, event, data):
@ -273,7 +273,6 @@ class ServerRequestHandler:
event_data = {"token": token}
event_str = json.dumps(event_data)
await self.send_sse_event(response, "message", event_str)
print(event_str)
await asyncio.sleep(0)
@ -288,7 +287,6 @@ class ServerRequestHandler:
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
tasks.append(generate_task)
#tasks.append(self.generate_text(newprompt, genparams))
try:
await asyncio.gather(*tasks)
@ -344,7 +342,7 @@ class ServerRequestHandler:
body = await request.content.read()
basic_api_flag = False
kai_api_flag = False
kai_sse_stream_flag = True
kai_sse_stream_flag = False
path = request.path.rstrip('/')
print(request)
@ -382,10 +380,10 @@ class ServerRequestHandler:
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
if not kai_sse_stream_flag:
return web.Response(body=gen)
modelbusy = False
if not kai_sse_stream_flag:
return web.Response(body=json.dumps(gen).encode())
return web.Response();
return web.Response(status=404)
@ -398,6 +396,11 @@ class ServerRequestHandler:
async def start_server(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((self.addr, self.port))
sock.listen(5)
self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
self.app.router.add_route('OPTIONS', '/', self.handle_options)
@ -405,7 +408,7 @@ class ServerRequestHandler:
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, self.addr, self.port)
site = web.SockSite(runner, sock)
await site.start()
# Keep Alive
@ -415,7 +418,11 @@ class ServerRequestHandler:
except KeyboardInterrupt:
await runner.cleanup()
await site.stop()
await exit(1)
await sys.exit(0)
finally:
await runner.cleanup()
await site.stop()
await sys.exit(0)
async def run_server(addr, port, embedded_kailite=None):
handler = ServerRequestHandler(addr, port, embedded_kailite)