diff --git a/expose.cpp b/expose.cpp
index 283438a4b..8142e85cd 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -210,7 +210,6 @@ extern "C"
 
     generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
     {
-        finished_stream = false;
         return gpttype_generate(inputs, output);
     }
 
@@ -230,6 +229,12 @@ extern "C"
     bool has_finished() {
         return finished_stream;
     }
+
+
+    // TODO: don't duplicate code
+    void bind_set_stream_finished(bool status) {
+        finished_stream = status;
+    }
 }
 
 void receive_current_token(std::string token) {
@@ -237,6 +242,6 @@ void receive_current_token(std::string token) {
     new_token_available = true;
 }
 
-void set_stream_finished() {
-    finished_stream = true;
+void set_stream_finished(bool status) {
+    finished_stream = status;
 }
diff --git a/expose.h b/expose.h
index dac7fffe0..fb360b48d 100644
--- a/expose.h
+++ b/expose.h
@@ -51,4 +51,4 @@
 extern std::string executable_path;
 extern std::string lora_filename;
 extern void receive_current_token(std::string token);
-extern void set_stream_finished();
+extern void set_stream_finished(bool status = true);
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index b49fd95f7..cdc0b5cf4 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1041,7 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
-            set_stream_finished();
+            set_stream_finished(true);
             return output;
         }
     }
@@ -1224,7 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
-    set_stream_finished();
+    set_stream_finished(true);
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 
     return output;
diff --git a/koboldcpp.py b/koboldcpp.py
index 553460057..772023221 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -5,8 +5,9 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time, asyncio
+import json, sys, time, asyncio
 from aiohttp import web
+from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
 
@@ -185,7 +186,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     else:
         inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
     inputs.seed = seed
-    for n in range(0,stop_token_max):
+    for n in range(stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
@@ -224,79 +225,96 @@ class ServerRequestHandler:
         self.port = port
         self.embedded_kailite = embedded_kailite
 
+
+    async def generate_text(self, newprompt, genparams):
+        loop = asyncio.get_event_loop()
+        executor = ThreadPoolExecutor()
+
+        def run_blocking():
+            # Reset finished status before generating
+            handle.bind_set_stream_finished(False)
+
+            return generate(prompt=newprompt,
+                            max_context_length=genparams.get('max_context_length', maxctx),
+                            max_length=genparams.get('max_length', 50),
+                            temperature=genparams.get('temperature', 0.8),
+                            top_k=genparams.get('top_k', 120),
+                            top_a=genparams.get('top_a', 0.0),
+                            top_p=genparams.get('top_p', 0.85),
+                            typical_p=genparams.get('typical', 1.0),
+                            tfs=genparams.get('tfs', 1.0),
+                            rep_pen=genparams.get('rep_pen', 1.1),
+                            rep_pen_range=genparams.get('rep_pen_range', 128),
+                            seed=genparams.get('sampler_seed', -1),
+                            stop_sequence=genparams.get('stop_sequence', [])
+                            )
+
+        try:
+            recvtxt = await loop.run_in_executor(executor, run_blocking)
+        except Exception:
+            print("Generate: Error while generating")
+            recvtxt = ""
+
+        res = {"results": [{"text": recvtxt}]}
+        return res
+
+
     async def send_sse_event(self, response, event, data):
         await response.write(f'event: {event}\n'.encode())
         await response.write(f'data: {data}\n\n'.encode())
 
-
     async def handle_sse_stream(self, request):
         response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
         await response.prepare(request)
 
-        stream_finished = False
-
-        while True:
-            if handle.has_finished():
-                stream_finished = True
+        while not handle.has_finished():
             if not handle.is_locked():
                 token = ctypes.string_at(handle.new_token()).decode('utf-8')
-                event_data = {"finished": stream_finished, "token": token}
-                event_str = f"data: {json.dumps(event_data)}"
+                event_data = {"token": token}
+                event_str = json.dumps(event_data)
                 await self.send_sse_event(response, "message", event_str)
-                print(token)
-                print(event_data)
-            if stream_finished:
-                break
+                print(event_str)
 
-    async def generate_text(self, newprompt, genparams):
-        recvtxt = generate(
-            prompt=newprompt,
-            max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 50),
-            temperature=genparams.get('temperature', 0.8),
-            top_k=genparams.get('top_k', 120),
-            top_a=genparams.get('top_a', 0.0),
-            top_p=genparams.get('top_p', 0.85),
-            typical_p=genparams.get('typical', 1.0),
-            tfs=genparams.get('tfs', 1.0),
-            rep_pen=genparams.get('rep_pen', 1.1),
-            rep_pen_range=genparams.get('rep_pen_range', 128),
-            seed=genparams.get('sampler_seed', -1),
-            stop_sequence=genparams.get('stop_sequence', [])
-        )
-        res = {"results": [{"text": recvtxt}]}
+            await asyncio.sleep(0)
 
-        try:
-            return web.json_response(res)
-        except:
-            print("Generate: The response could not be sent, maybe the connection was terminated?")
+        await response.write_eof()
+        response.force_close()
 
     async def handle_request(self, request, genparams, newprompt, stream_flag):
+        tasks = []
+
         if stream_flag:
-            self.handle_sse_stream(request)
-            # RUN THESE CONCURRENTLY WITHOUT BLOCKING EACHOTHER
-            self.generate_text(newprompt, genparams)
+            tasks.append(self.handle_sse_stream(request))
+
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
+        tasks.append(generate_task)
+
+        try:
+            await asyncio.gather(*tasks)
+            if not stream_flag:
+                generate_result = generate_task.result()
+                return generate_result
+        except Exception as e:
+            print(e)
 
     async def handle_get(self, request):
         global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
         path = request.path.rstrip('/')
 
-        if path in ["", "/?"] or path.startswith(('/?','?')):
+        if path in ["", "/?"] or path.startswith(('/?', '?')):
             if args.stream and not "streaming=1" in path:
-                path = path.replace("streaming=0","")
-                if path.startswith(('/?','?')):
+                path = path.replace("streaming=0", "")
+                if path.startswith(('/?', '?')):
                     path += "&streaming=1"
                 else:
                     path = path + "?streaming=1"
-
                 raise web.HTTPFound(path)
 
             if self.embedded_kailite is None:
-                return web.Response(
-                    text="Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or use this URL to connect."
-                )
+                return web.Response(body="Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or use this URL to connect.".encode())
             else:
-                return web.Response(body=self.embedded_kailite)
+                return web.Response(body=self.embedded_kailite, content_type='text/html')
 
         elif path.endswith(('/api/v1/model', '/api/latest/model')):
             return web.json_response({'result': friendlymodelname})
@@ -326,7 +344,7 @@ class ServerRequestHandler:
         body = await request.content.read()
         basic_api_flag = False
         kai_api_flag = False
-        kai_stream_flag = True
+        kai_sse_stream_flag = False
         path = request.path.rstrip('/')
 
         print(request)
@@ -344,7 +362,7 @@ class ServerRequestHandler:
 
         if path.endswith('/api/v1/generate/stream'):
             kai_api_flag = True
-            kai_stream_flag = True
+            kai_sse_stream_flag = True
 
         if basic_api_flag or kai_api_flag:
             genparams = None
@@ -362,10 +380,13 @@ class ServerRequestHandler:
                 fullprompt = genparams.get('text', "")
             newprompt = fullprompt
 
-            await self.handle_request(request, genparams, newprompt, kai_stream_flag)
+            gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
+
+            if not kai_sse_stream_flag:
+                return web.json_response(gen)
 
             modelbusy = False
             return web.Response()
 
         return web.Response(status=404)
 
@@ -393,6 +414,8 @@ class ServerRequestHandler:
                 await asyncio.sleep(3600)
         except KeyboardInterrupt:
             await runner.cleanup()
+            await site.stop()
+            sys.exit(1)
 
 async def run_server(addr, port, embedded_kailite=None):
     handler = ServerRequestHandler(addr, port, embedded_kailite)
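
Usage sketch (illustrative, not part of the patch): a minimal asyncio client for the /api/v1/generate/stream endpoint added above. The endpoint path, the "message" event name, and the {"token": ...} payload come straight from the diff; the host and port (KoboldCpp's default 5001) and the KoboldAI-style request fields ("prompt", "max_length") are assumptions and may need adjusting.

    # sse_client_sketch.py -- hypothetical client, not part of this patch
    import asyncio
    import json

    import aiohttp

    async def stream_tokens():
        # Assumed request fields; the server reads its sampler settings via genparams.get(...)
        payload = {"prompt": "Once upon a time,", "max_length": 50}
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:5001/api/v1/generate/stream",
                                    json=payload) as resp:
                # The server emits "event: message" / "data: {...}" frames per token
                # and ends the stream once handle.has_finished() turns true.
                async for raw in resp.content:
                    line = raw.decode("utf-8").strip()
                    if line.startswith("data:"):
                        event_data = json.loads(line[len("data:"):].strip())
                        print(event_data["token"], end="", flush=True)
        print()

    asyncio.run(stream_tokens())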
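Design note: generate() blocks inside the ctypes call into gpttype_generate, so running it via run_in_executor on a ThreadPoolExecutor keeps the event loop free to service handle_sse_stream concurrently through asyncio.gather. The handle.bind_set_stream_finished(False) call before each run takes over the reset that this patch removes from the C++ generate() wrapper.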