From 97971291e9e5d05c428d4c0f0cb8f956e36a63c5 Mon Sep 17 00:00:00 2001
From: SammCheese
Date: Wed, 7 Jun 2023 00:48:00 +0200
Subject: [PATCH] draft: token streaming

---
 expose.cpp          |  32 +++++
 expose.h            |   4 +
 gpttype_adapter.cpp |  11 +-
 koboldcpp.py        | 306 ++++++++++++++++++++------------------------
 4 files changed, 183 insertions(+), 170 deletions(-)

diff --git a/expose.cpp b/expose.cpp
index 6388bed88..283438a4b 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -23,6 +23,11 @@
 std::string executable_path = "";
 std::string lora_filename = "";
+
+static std::string current_token = "";
+static bool new_token_available = false;
+static bool finished_stream = false;
+
 extern "C"
 {

@@ -205,6 +210,33 @@ extern "C"
     generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
     {
+        finished_stream = false;
         return gpttype_generate(inputs, output);
     }
+
+    const char* new_token() {
+        if (new_token_available) {
+            new_token_available = false;
+            return current_token.c_str();
+        }
+        return nullptr;
+    }
+
+    bool is_locked() {
+        return !new_token_available;
+    }
+
+    bool has_finished() {
+        return finished_stream;
+    }
+}
+
+void receive_current_token(std::string token) {
+    current_token = token;
+    new_token_available = true;
+}
+
+void set_stream_finished() {
+    finished_stream = true;
+}

diff --git a/expose.h b/expose.h
index cb475f141..dac7fffe0 100644
--- a/expose.h
+++ b/expose.h
@@ -18,6 +18,7 @@ struct load_model_inputs
     const int clblast_info = 0;
     const int blasbatchsize = 512;
     const bool debugmode;
+    const bool stream_sse;
     const int forceversion = 0;
     const int gpulayers = 0;
 };
@@ -48,3 +49,6 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
+
+extern void receive_current_token(std::string token);
+extern void set_stream_finished();

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index d7c334c00..b49fd95f7 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -63,6 +63,7 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
+static bool stream_sse = true;
 static std::string modelname;
 static std::vector last_n_tokens;
 static std::vector current_context_tokens;
@@ -1040,6 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
+            set_stream_finished();
             return output;
         }
     }
@@ -1149,7 +1151,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         for (auto id : embd)
         {
-            concat_output += FileFormatTokenizeID(id,file_format);
+            std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
+
+            if (stream_sse)
+            {
+                receive_current_token(tokenizedstr);
+            }
+            concat_output += tokenizedstr;
         }

         if (startedsampling)
@@ -1216,6 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
         fflush(stdout);
         output.status = 1;
+        set_stream_finished();
         snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());

         return output;
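The exported C functions above form a deliberately simple single-slot handoff: receive_current_token fills the slot from the generation loop, new_token empties it, and is_locked/has_finished tell a poller when to read and when to stop (a token that is not fetched before the next one arrives is overwritten). A minimal consumer sketch from the Python side, assuming `handle` is the ctypes library configured as in init_library() further down; drain_tokens is a name invented here for illustration:

    import time

    def drain_tokens(handle):
        # Poll the single-token slot exposed by expose.cpp. new_token() clears
        # the new_token_available flag as it returns, so each token is read once.
        while True:
            finished = handle.has_finished()
            if not handle.is_locked():
                raw = handle.new_token()        # bytes via c_char_p restype, or None
                if raw is not None:
                    yield raw.decode("utf-8", errors="ignore")
            if finished:
                break
            time.sleep(0.005)                   # don't spin between tokens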
diff --git a/koboldcpp.py b/koboldcpp.py
index cf8847bc8..553460057 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -5,7 +5,8 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time
+import json, http.server, threading, socket, sys, time, asyncio
+from aiohttp import web

 stop_token_max = 10

@@ -134,6 +135,7 @@ def init_library():
     handle.load_model.restype = ctypes.c_bool
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
+    handle.new_token.restype = ctypes.c_char_p

 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -212,105 +214,145 @@
 modelbusy = False
 defaultport = 5001
 KcppVersion = "1.29"

-class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
+class ServerRequestHandler:
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
+    app = web.Application()

     def __init__(self, addr, port, embedded_kailite):
         self.addr = addr
         self.port = port
         self.embedded_kailite = embedded_kailite

-    def __call__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    async def send_sse_event(self, response, event, data):
+        await response.write(f'event: {event}\n'.encode())
+        await response.write(f'data: {data}\n\n'.encode())

+    async def handle_sse_stream(self, request):
+        response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
+        await response.prepare(request)
+
+        stream_finished = False
+
+        while True:
+            if handle.has_finished():
+                stream_finished = True
+            if not handle.is_locked():
+                # new_token() has restype c_char_p, so it arrives as bytes
+                # (or None if the flag was cleared by a race)
+                raw = handle.new_token()
+                if raw is not None:
+                    token = raw.decode('utf-8')
+                    event_data = {"finished": stream_finished, "token": token}
+                    await self.send_sse_event(response, "message", json.dumps(event_data))
+            if stream_finished:
+                break
+            # yield to the event loop instead of spinning at 100% CPU
+            await asyncio.sleep(0.01)
+
+        await response.write_eof()
+        return response
+
+    async def generate_text(self, newprompt, genparams):
+        loop = asyncio.get_running_loop()
+        # generate() is a blocking ctypes call; run it in a worker thread so
+        # the event loop can keep pumping the SSE stream meanwhile
+        recvtxt = await loop.run_in_executor(None, lambda: generate(
+            prompt=newprompt,
+            max_context_length=genparams.get('max_context_length', maxctx),
+            max_length=genparams.get('max_length', 50),
+            temperature=genparams.get('temperature', 0.8),
+            top_k=int(genparams.get('top_k', 120)),
+            top_a=genparams.get('top_a', 0.0),
+            top_p=genparams.get('top_p', 0.85),
+            typical_p=genparams.get('typical', 1.0),
+            tfs=genparams.get('tfs', 1.0),
+            rep_pen=genparams.get('rep_pen', 1.1),
+            rep_pen_range=genparams.get('rep_pen_range', 128),
+            seed=genparams.get('sampler_seed', -1),
+            stop_sequence=genparams.get('stop_sequence', [])
+        ))
+        res = {"results": [{"text": recvtxt}]}
+
+        try:
+            return web.json_response(res)
+        except:
+            print("Generate: The response could not be sent, maybe the connection was terminated?")
+
+    async def handle_request(self, request, genparams, newprompt, stream_flag):
+        if stream_flag:
+            # coroutines only run once awaited or scheduled; gather runs the
+            # SSE writer and the generator concurrently without blocking each other
+            stream_response, _ = await asyncio.gather(
+                self.handle_sse_stream(request),
+                self.generate_text(newprompt, genparams),
+            )
+            return stream_response
+        return await self.generate_text(newprompt, genparams)
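For reference, this is roughly how a client consumes the stream as drafted: one "event: message" plus "data: {json}" pair per token, with a "finished" flag inside the JSON payload. A sketch using aiohttp's client side; the host, port, and prompt are placeholders:

    import asyncio, json
    import aiohttp

    async def stream_generate(prompt):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://localhost:5001/api/v1/generate/stream",
                json={"prompt": prompt, "max_length": 50},
            ) as resp:
                async for raw in resp.content:          # iterate the SSE stream line by line
                    line = raw.decode("utf-8").strip()
                    if not line.startswith("data: "):
                        continue                         # skip "event:" lines and blank separators
                    event = json.loads(line[len("data: "):])
                    print(event["token"], end="", flush=True)
                    if event["finished"]:
                        break

    asyncio.run(stream_generate("Once upon a time"))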
-    def do_GET(self):
-        global maxctx, maxlen, friendlymodelname, KcppVersion
-        self.path = self.path.rstrip('/')
-        response_body = None
-        if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
-            if args.stream and not "streaming=1" in self.path:
-                self.path = self.path.replace("streaming=0","")
-                if self.path.startswith(('/?','?')):
-                    self.path += "&streaming=1"
-                else:
-                    self.path = self.path + "?streaming=1"
-                self.send_response(302)
-                self.send_header("Location", self.path)
-                self.end_headers()
-                print("Force redirect to streaming mode, as --stream is set.")
-                return None
+    async def handle_get(self, request):
+        global maxctx, maxlen, friendlymodelname, KcppVersion
+        # path_qs keeps the ?streaming=... query string that the old self.path carried
+        path = request.path_qs.rstrip('/')
+
+        if path in ["", "/?"] or path.startswith(('/?','?')):
+            if args.stream and not "streaming=1" in path:
+                path = path.replace("streaming=0","")
+                if path.startswith(('/?','?')):
+                    path += "&streaming=1"
+                else:
+                    path = path + "?streaming=1"
+                raise web.HTTPFound(path)

             if self.embedded_kailite is None:
-                response_body = (f"Embedded Kobold Lite is not found. You will have to connect via the main KoboldAI client, or use this URL to connect.").encode()
+                return web.Response(
+                    text="Embedded Kobold Lite is not found. You will have to connect via the main KoboldAI client, or use this URL to connect."
+                )
             else:
-                response_body = self.embedded_kailite
+                return web.Response(body=self.embedded_kailite, content_type='text/html')

-        elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
-            response_body = (json.dumps({'result': friendlymodelname }).encode())
+        elif path.endswith(('/api/v1/model', '/api/latest/model')):
+            return web.json_response({'result': friendlymodelname})

-        elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
-            response_body = (json.dumps({"value": maxlen}).encode())
+        elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
+            return web.json_response({"value": maxlen})

-        elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
-            response_body = (json.dumps({"value": maxctx}).encode())
+        elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
+            return web.json_response({"value": maxctx})

-        elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
-            response_body = (json.dumps({"value":""}).encode())
+        elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
+            return web.json_response({"value": ""})

-        elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
-            response_body = (json.dumps({"values": []}).encode())
+        elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
+            return web.json_response({"values": []})

-        elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
-            response_body = (json.dumps({"result":"1.2.2"}).encode())
+        elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
+            return web.json_response({"result": "1.2.2"})

-        elif self.path.endswith(('/api/extra/version')):
-            response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
+        elif path.endswith(('/api/extra/version')):
+            return web.json_response({"result": "KoboldCpp", "version": KcppVersion})

-        if response_body is None:
-            self.send_response(404)
-            self.end_headers()
-            rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
-            self.wfile.write(rp.encode())
-        else:
-            self.send_response(200)
-            self.send_header('Content-Length', str(len(response_body)))
-            self.end_headers()
-            self.wfile.write(response_body)
-        return
+        return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.")
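The read-only GET endpoints map one for one onto the old do_GET branches. A quick sketch of querying a few of them, assuming a server is listening on the default port 5001:

    import json, urllib.request

    base = "http://localhost:5001"
    for ep in ("/api/v1/model", "/api/v1/config/max_context_length", "/api/extra/version"):
        with urllib.request.urlopen(base + ep) as r:
            print(ep, "->", json.load(r))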
-    def do_POST(self):
+    async def handle_post(self, request):
         global modelbusy
-        content_length = int(self.headers['Content-Length'])
-        body = self.rfile.read(content_length)
+        body = await request.content.read()
         basic_api_flag = False
         kai_api_flag = False
-        self.path = self.path.rstrip('/')
+        kai_stream_flag = False
+        path = request.path.rstrip('/')

         if modelbusy:
-            self.send_response(503)
-            self.end_headers()
-            self.wfile.write(json.dumps({"detail": {
-                "msg": "Server is busy; please try again later.",
-                "type": "service_unavailable",
-            }}).encode())
-            return
+            return web.json_response(
+                {"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}},
+                status=503,
+            )

-        if self.path.endswith('/request'):
+        if path.endswith('/request'):
             basic_api_flag = True

-        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
+        if path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True

+        if path.endswith('/api/v1/generate/stream'):
+            kai_api_flag = True
+            kai_stream_flag = True
+
         if basic_api_flag or kai_api_flag:
             genparams = None
             try:
                 genparams = json.loads(body)
             except ValueError as e:
-                self.send_response(503)
-                self.end_headers()
-                return
+                return web.Response(status=503)
+
             utfprint("\nInput: " + json.dumps(genparams))
             modelbusy = True
@@ -320,115 +362,41 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 fullprompt = genparams.get('text', "")
             newprompt = fullprompt

-            recvtxt = ""
-            res = {}
-            if kai_api_flag:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_context_length=genparams.get('max_context_length', maxctx),
-                    max_length=genparams.get('max_length', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"results": [{"text": recvtxt}]}
-            else:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_length=genparams.get('max', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"data": {"seqs":[recvtxt]}}
+            response = await self.handle_request(request, genparams, newprompt, kai_stream_flag)

-            try:
-                self.send_response(200)
-                self.end_headers()
-                self.wfile.write(json.dumps(res).encode())
-            except:
-                print("Generate: The response could not be sent, maybe connection was terminated?")
             modelbusy = False
-            return
+            # generate_text may not have produced a response if the peer vanished
+            return response if response is not None else web.Response()

-        self.send_response(404)
-        self.end_headers()
+        return web.Response(status=404)

-    def do_OPTIONS(self):
-        self.send_response(200)
-        self.end_headers()
+    async def handle_options(self, request):
+        return web.Response()

-    def do_HEAD(self):
-        self.send_response(200)
-        self.end_headers()
+    async def handle_head(self, request):
+        return web.Response()

-    def end_headers(self):
-        self.send_header('Access-Control-Allow-Origin', '*')
-        self.send_header('Access-Control-Allow-Methods', '*')
-        self.send_header('Access-Control-Allow-Headers', '*')
-        if "/api" in self.path:
-            self.send_header('Content-type', 'application/json')
-        else:
-            self.send_header('Content-type', 'text/html')
-
-        return super(ServerRequestHandler, self).end_headers()
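Note that the removed end_headers override used to attach Access-Control-Allow-* headers to every response, and the aiohttp draft drops them, so cross-origin clients such as a separately hosted Kobold Lite could break. A possible aiohttp middleware restoring that behavior; this is a sketch, not part of the patch (the per-path Content-type logic is no longer needed because each handler now sets its own content type):

    @web.middleware
    async def cors_middleware(request, handler):
        # mirror the headers the old end_headers() override always sent
        response = await handler(request)
        response.headers['Access-Control-Allow-Origin'] = '*'
        response.headers['Access-Control-Allow-Methods'] = '*'
        response.headers['Access-Control-Allow-Headers'] = '*'
        return response

    # e.g. app = web.Application(middlewares=[cors_middleware])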
+    async def start_server(self):
+        self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
+        self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
+        self.app.router.add_route('OPTIONS', '/', self.handle_options)
+        self.app.router.add_route('HEAD', '/', self.handle_head)
+
+        runner = web.AppRunner(self.app)
+        await runner.setup()
+        site = web.TCPSite(runner, self.addr, self.port)
+        await site.start()
+
+        # Keep alive; cleanup runs when the task is cancelled (e.g. Ctrl-C
+        # under asyncio.run), which a bare `except KeyboardInterrupt` would miss
+        try:
+            while True:
+                await asyncio.sleep(3600)
+        finally:
+            await runner.cleanup()

-def RunServerMultiThreaded(addr, port, embedded_kailite = None):
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    sock.bind((addr, port))
-    sock.listen(5)
-
-    class Thread(threading.Thread):
-        def __init__(self, i):
-            threading.Thread.__init__(self)
-            self.i = i
-            self.daemon = True
-            self.start()
-
-        def run(self):
-            handler = ServerRequestHandler(addr, port, embedded_kailite)
-            with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
-                try:
-                    self.httpd.socket = sock
-                    self.httpd.server_bind = self.server_close = lambda self: None
-                    self.httpd.serve_forever()
-                except (KeyboardInterrupt,SystemExit):
-                    self.httpd.server_close()
-                    sys.exit(0)
-                finally:
-                    self.httpd.server_close()
-                    sys.exit(0)
-        def stop(self):
-            self.httpd.server_close()
-
-    numThreads = 6
-    threadArr = []
-    for i in range(numThreads):
-        threadArr.append(Thread(i))
-    while 1:
-        try:
-            time.sleep(10)
-        except KeyboardInterrupt:
-            for i in range(numThreads):
-                threadArr[i].stop()
-            sys.exit(0)
+
+async def run_server(addr, port, embedded_kailite=None):
+    handler = ServerRequestHandler(addr, port, embedded_kailite)
+    await handler.start_server()

@@ -500,15 +468,15 @@ def show_gui():
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
+    frm3 = tk.Frame(root)
+    tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
+    tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
+    tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
+    tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
+    tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
+    tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)

-    frameD = tk.Frame(root)
-    tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
-    tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
-    tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
-    tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
-    tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
-    tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
-    frameD.grid(row=5,column=0,pady=4)
+    frm3.grid(row=5,column=0,pady=4)

     # Create button, it will change label text
     tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0)
@@ -688,7 +656,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
             print(f"Please connect to custom endpoint at {epurl}")
-    RunServerMultiThreaded(args.host, args.port, embedded_kailite)
+    asyncio.run(run_server(args.host, args.port, embedded_kailite))

 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually
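With the patch applied, the blocking endpoint still behaves as before. A minimal end-to-end smoke test, assuming a model is loaded and the server is listening on the default port 5001 (the prompt and sampler values are arbitrary):

    import json, urllib.request

    payload = {"prompt": "Hello, kobold. ", "max_length": 32}
    req = urllib.request.Request(
        "http://localhost:5001/api/v1/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as r:
        print(json.load(r)["results"][0]["text"])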