working streaming. TODO: fix lite

SammCheese 2023-06-08 06:18:23 +02:00
parent 97971291e9
commit 9a8da35ec4
No known key found for this signature in database
GPG key ID: 28CFE2321A140BA1
4 changed files with 84 additions and 56 deletions


@@ -210,7 +210,6 @@ extern "C"
     generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
     {
-        finished_stream = false;
         return gpttype_generate(inputs, output);
     }
@@ -230,6 +229,12 @@ extern "C"
     bool has_finished() {
         return finished_stream;
     }
+    // TODO: dont duplicate code
+    void bind_set_stream_finished(bool status) {
+        finished_stream = status;
+    }
 }

 void receive_current_token(std::string token) {
@@ -237,6 +242,6 @@ void receive_current_token(std::string token) {
     new_token_available = true;
 }
-void set_stream_finished() {
-    finished_stream = true;
+void set_stream_finished(bool status) {
+    finished_stream = status;
 }
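
Taken together, these hunks turn the stream-finished flag into externally settable state: set_stream_finished now takes a status, and the new exported bind_set_stream_finished gives the Python server a way to clear the flag before each run (the duplication is the commit's own TODO). From Python the export would be reached through ctypes; a minimal sketch of the declarations, assuming handle is the ctypes.CDLL that koboldcpp.py already loads (the real loader may set these up differently):

import ctypes

def declare_stream_exports(handle: ctypes.CDLL) -> None:
    # bool has_finished(): polled by the SSE loop to detect completion
    handle.has_finished.argtypes = []
    handle.has_finished.restype = ctypes.c_bool
    # void bind_set_stream_finished(bool): lets Python clear (or set)
    # the finished flag before starting a new generation
    handle.bind_set_stream_finished.argtypes = [ctypes.c_bool]
    handle.bind_set_stream_finished.restype = None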


@@ -51,4 +51,4 @@ extern std::string executable_path;
 extern std::string lora_filename;
 extern void receive_current_token(std::string token);
-extern void set_stream_finished();
+extern void set_stream_finished(bool status = true);


@@ -1041,7 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
-            set_stream_finished();
+            set_stream_finished(true);
             return output;
         }
     }
@@ -1224,7 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
-    set_stream_finished();
+    set_stream_finished(true);
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
     return output;
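
Both exit paths of gpttype_generate (failure, status 0, and success, status 1) now raise the flag explicitly, though the "= true" default added in the header would have kept the old no-argument calls compiling anyway. The behavioral change sits in the first file: generate() no longer resets the flag on entry, so clearing it becomes the caller's job, which is exactly what the new run_blocking in the Python server does. The ordering it relies on, roughly:

# Names from the diff below; skipping the reset would leave a stale
# finished flag from the previous run and end the next SSE loop immediately.
handle.bind_set_stream_finished(False)   # clear leftover state first
recvtxt = generate(prompt=newprompt)     # (sampler params omitted here)
                                         # the C++ side re-raises the flag when done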


@@ -5,8 +5,9 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time, asyncio
+import json, sys, time, asyncio
 from aiohttp import web
+from concurrent.futures import ThreadPoolExecutor

 stop_token_max = 10
@@ -185,7 +186,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     else:
         inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
     inputs.seed = seed
-    for n in range(0,stop_token_max):
+    for n in range(stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
@@ -224,79 +225,96 @@ class ServerRequestHandler:
         self.port = port
         self.embedded_kailite = embedded_kailite

+    async def generate_text(self, newprompt, genparams):
+        loop = asyncio.get_event_loop()
+        executor = ThreadPoolExecutor()
+        def run_blocking():
+            # Reset finished status before generating
+            handle.bind_set_stream_finished(False)
+            return generate(prompt=newprompt,
+                max_context_length=genparams.get('max_context_length', maxctx),
+                max_length=genparams.get('max_length', 50),
+                temperature=genparams.get('temperature', 0.8),
+                top_k=genparams.get('top_k', 120),
+                top_a=genparams.get('top_a', 0.0),
+                top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
+                rep_pen=genparams.get('rep_pen', 1.1),
+                rep_pen_range=genparams.get('rep_pen_range', 128),
+                seed=genparams.get('sampler_seed', -1),
+                stop_sequence=genparams.get('stop_sequence', [])
+            )
+        recvtxt = await loop.run_in_executor(executor, run_blocking)
+        res = {"results": [{"text": recvtxt}]}
+        try:
+            return res
+        except:
+            print("Generate: Error while generating")

     async def send_sse_event(self, response, event, data):
         await response.write(f'event: {event}\n'.encode())
         await response.write(f'data: {data}\n\n'.encode())

     async def handle_sse_stream(self, request):
         response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
         await response.prepare(request)

-        stream_finished = False
-        while True:
-            if handle.has_finished():
-                stream_finished = True
+        while not handle.has_finished():
             if not handle.is_locked():
                 token = ctypes.string_at(handle.new_token()).decode('utf-8')
-                event_data = {"finished": stream_finished, "token": token}
-                event_str = f"data: {json.dumps(event_data)}"
+                event_data = {"token": token}
+                event_str = json.dumps(event_data)
                 await self.send_sse_event(response, "message", event_str)
-                print(token)
-                print(event_data)
-            if stream_finished:
-                break
+                print(event_str)
+            await asyncio.sleep(0)

-    async def generate_text(self, newprompt, genparams):
-        recvtxt = generate(
-            prompt=newprompt,
-            max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 50),
-            temperature=genparams.get('temperature', 0.8),
-            top_k=genparams.get('top_k', 120),
-            top_a=genparams.get('top_a', 0.0),
-            top_p=genparams.get('top_p', 0.85),
-            typical_p=genparams.get('typical', 1.0),
-            tfs=genparams.get('tfs', 1.0),
-            rep_pen=genparams.get('rep_pen', 1.1),
-            rep_pen_range=genparams.get('rep_pen_range', 128),
-            seed=genparams.get('sampler_seed', -1),
-            stop_sequence=genparams.get('stop_sequence', [])
-        )
-        res = {"results": [{"text": recvtxt}]}
-        try:
-            return web.json_response(res)
-        except:
-            print("Generate: The response could not be sent, maybe the connection was terminated?")
+        await response.write_eof()
+        await response.force_close()

     async def handle_request(self, request, genparams, newprompt, stream_flag):
+        tasks = []
         if stream_flag:
-            self.handle_sse_stream(request)
+            tasks.append(self.handle_sse_stream(request,))

-        self.generate_text(newprompt, genparams)
+        # RUN THESE CONCURRENTLY WITHOUT BLOCKING EACHOTHER
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
+        tasks.append(generate_task)
+        #tasks.append(self.generate_text(newprompt, genparams))
+        try:
+            await asyncio.gather(*tasks)
+            if not stream_flag:
+                generate_result = generate_task.result()
+                return generate_result
+        except Exception as e:
+            print(e)

     async def handle_get(self, request):
         global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
         path = request.path.rstrip('/')

-        if path in ["", "/?"] or path.startswith(('/?','?')):
+        if path in ["", "/?"] or path.startswith(('/?', '?')):
             if args.stream and not "streaming=1" in path:
-                path = path.replace("streaming=0","")
-                if path.startswith(('/?','?')):
+                path = path.replace("streaming=0", "")
+                if path.startswith(('/?', '?')):
                     path += "&streaming=1"
                 else:
                     path = path + "?streaming=1"
+                raise web.HTTPFound(path)

             if self.embedded_kailite is None:
-                return web.Response(
-                    text="Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect."
-                )
+                return web.Response(body=f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.".encode())
             else:
-                return web.Response(body=self.embedded_kailite)
+                return web.Response(body=self.embedded_kailite, content_type='text/html')

         elif path.endswith(('/api/v1/model', '/api/latest/model')):
             return web.json_response({'result': friendlymodelname})
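
This is the heart of the server change: generate() is a blocking ctypes call, so the reworked generate_text pushes it onto a ThreadPoolExecutor via run_in_executor, while handle_sse_stream stays on the event loop, polling the C side (has_finished, is_locked, new_token) and relaying each token as a server-sent event; "await asyncio.sleep(0)" yields between polls so the two coroutines interleave under asyncio.gather. A self-contained sketch of the same pattern, with a fake blocking generator standing in for the ctypes layer (every name here is illustrative, none of it is koboldcpp's API):

import asyncio
import json
import time
from concurrent.futures import ThreadPoolExecutor

from aiohttp import web

executor = ThreadPoolExecutor()
state = {"token": None, "finished": True}   # single token slot plus a flag

def blocking_generate(prompt: str) -> str:
    # Stand-in for the blocking C call: emits one word at a time.
    out = []
    for word in ("echo: " + prompt).split():
        time.sleep(0.2)               # simulate per-token latency
        state["token"] = word + " "   # publish the token into the slot
        out.append(word + " ")
    state["finished"] = True          # same role as set_stream_finished(true)
    return "".join(out)

async def handle_stream(request: web.Request) -> web.StreamResponse:
    resp = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
    await resp.prepare(request)

    state["finished"] = False         # same role as bind_set_stream_finished(False)
    loop = asyncio.get_running_loop()
    gen = loop.run_in_executor(executor, blocking_generate,
                               request.query.get("prompt", "hello world"))

    while not state["finished"]:
        tok, state["token"] = state["token"], None
        if tok is not None:
            data = json.dumps({"token": tok})
            await resp.write(f"event: message\ndata: {data}\n\n".encode())
        await asyncio.sleep(0)        # yield so other coroutines can run

    tok = state["token"]              # drain a token that raced with the flag
    if tok is not None:
        await resp.write(f"event: message\ndata: {json.dumps({'token': tok})}\n\n".encode())

    await gen                         # generation is done; collect the full text
    await resp.write_eof()
    return resp

app = web.Application()
app.add_routes([web.get("/stream", handle_stream)])

if __name__ == "__main__":
    web.run_app(app, port=8080)

Like the commit, the sketch keeps a single token slot rather than a queue, so a fast producer could overwrite a token before the poller picks it up; a thread-safe queue would close that gap, at the cost of diverging from the C-side interface the real code binds to.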
@@ -326,7 +344,7 @@ class ServerRequestHandler:
         body = await request.content.read()
         basic_api_flag = False
         kai_api_flag = False
-        kai_stream_flag = True
+        kai_sse_stream_flag = True
         path = request.path.rstrip('/')
         print(request)
@@ -344,7 +362,7 @@ class ServerRequestHandler:
         if path.endswith('/api/v1/generate/stream'):
             kai_api_flag = True
-            kai_stream_flag = True
+            kai_sse_stream_flag = True

         if basic_api_flag or kai_api_flag:
             genparams = None
@@ -362,10 +380,13 @@ class ServerRequestHandler:
                 fullprompt = genparams.get('text', "")
             newprompt = fullprompt

-            await self.handle_request(request, genparams, newprompt, kai_stream_flag)
+            gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
+            if not kai_sse_stream_flag:
+                return web.Response(body=gen)

             modelbusy = False
-            return web.Response()
+            return web.Response();

         return web.Response(status=404)
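
With non-streaming requests now returning the generated text through handle_request, the streaming endpoint can be exercised end to end with a small client. A hypothetical test script, assuming the server is up on KoboldCpp's usual default port 5001 (adjust to your setup):

import asyncio
import json

import aiohttp

async def main() -> None:
    payload = {"prompt": "Once upon a time", "max_length": 32}
    async with aiohttp.ClientSession() as session:
        async with session.post(
                "http://localhost:5001/api/v1/generate/stream",
                json=payload) as resp:
            # The server frames each token as "event: message" + "data: {...}".
            async for raw in resp.content:
                line = raw.decode("utf-8").strip()
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):])["token"],
                          end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())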
@@ -393,6 +414,8 @@ class ServerRequestHandler:
             await asyncio.sleep(3600)
         except KeyboardInterrupt:
             await runner.cleanup()
+            await site.stop()
+            await exit(1)

 async def run_server(addr, port, embedded_kailite=None):
     handler = ServerRequestHandler(addr, port, embedded_kailite)
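
One caveat in the new shutdown lines: the exit builtin raises SystemExit the moment it is called, so the await in "await exit(1)" never actually takes effect and the line works only by accident. A possible tidy-up (an assumption on my part, not part of this commit), sketched against the serve-forever loop the hunk shows:

import asyncio
import sys

async def serve_forever(runner, site):
    try:
        while True:
            await asyncio.sleep(3600)
    except KeyboardInterrupt:
        await site.stop()       # stop accepting new connections first
        await runner.cleanup()  # then release the app's resources
        sys.exit(1)             # explicit, non-awaited exit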