From e6231c30553b0720ffdda04106625e3a56b32ae5 Mon Sep 17 00:00:00 2001
From: SammCheese
Date: Fri, 9 Jun 2023 12:17:55 +0200
Subject: [PATCH] back to http.server, improved implementation

---
 expose.cpp          |  35 ++----
 expose.h            |   4 +-
 gpttype_adapter.cpp |   9 +-
 koboldcpp.py        | 298 ++++++++++++++++++++++++++------------------
 4 files changed, 196 insertions(+), 150 deletions(-)

diff --git a/expose.cpp b/expose.cpp
index 8142e85cd..6e4f656e1 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -24,9 +24,8 @@
 std::string executable_path = "";
 std::string lora_filename = "";
-static std::string current_token = "";
-static bool new_token_available = false;
-static bool finished_stream = false;
+bool generation_finished;
+std::vector<std::string> generated_tokens;

 extern "C"
 {
@@ -213,35 +212,17 @@ extern "C"
         return gpttype_generate(inputs, output);
     }

+    const char* new_token(int idx) {
+        if (idx < 0 || generated_tokens.size() <= (size_t)idx) return nullptr;

-    const char* new_token() {
-        if (new_token_available) {
-            new_token_available = false;
-            return current_token.c_str();
-        }
-        return nullptr;
+        return generated_tokens[idx].c_str();
     }

-    bool is_locked() {
-        return !new_token_available;
+    int get_stream_count() {
+        return generated_tokens.size();
     }

     bool has_finished() {
-        return finished_stream;
-    }
-
-
-    // TODO: dont duplicate code
-    void bind_set_stream_finished(bool status) {
-        finished_stream = status;
+        return generation_finished;
     }
 }
-
-void receive_current_token(std::string token) {
-    current_token = token;
-    new_token_available = true;
-}
-
-void set_stream_finished(bool status) {
-    finished_stream = status;
-}
diff --git a/expose.h b/expose.h
index fb360b48d..bb9c5920b 100644
--- a/expose.h
+++ b/expose.h
@@ -50,5 +50,5 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
-extern void receive_current_token(std::string token);
-extern void set_stream_finished(bool status = true);
+extern std::vector<std::string> generated_tokens;
+extern bool generation_finished;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index cdc0b5cf4..1c75a8b7c 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -736,6 +736,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_batch = n_batch;
     params.n_threads = n_threads;

+    generation_finished = false; // Reset the finished flag for this generation
+    generated_tokens.clear(); // New generation, new tokens
+
     if (params.repeat_last_n < 1)
     {
         params.repeat_last_n = 1;
@@ -1041,7 +1044,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
-            set_stream_finished(true);
+            generation_finished = true;
             return output;
         }
     }
@@ -1155,7 +1158,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o

             if (stream_sse)
             {
-                receive_current_token(tokenizedstr);
+                generated_tokens.push_back(tokenizedstr);
             }
             concat_output += tokenizedstr;
         }
@@ -1224,7 +1227,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
-    set_stream_finished(true);
+    generation_finished = true;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());

     return output;
diff --git a/koboldcpp.py b/koboldcpp.py
index 7320b2ce0..e7ea0748e 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -5,8 +5,7 @@
 import ctypes
 import os
 import argparse
-import json, sys, time, asyncio, socket
-from aiohttp import web
+import json, sys, http.server, time, asyncio, socket, threading
 from concurrent.futures import ThreadPoolExecutor

 stop_token_max = 10
@@ -137,6 +136,9 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
     handle.new_token.restype = ctypes.c_char_p
+    handle.new_token.argtypes = [ctypes.c_int]
+    handle.get_stream_count.restype = ctypes.c_int
+    handle.has_finished.restype = ctypes.c_bool

 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -215,25 +217,23 @@
 modelbusy = False
 defaultport = 5001
 KcppVersion = "1.29"

-class ServerRequestHandler:
+class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
-    app = web.Application()

     def __init__(self, addr, port, embedded_kailite):
         self.addr = addr
         self.port = port
         self.embedded_kailite = embedded_kailite

+    def __call__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     async def generate_text(self, newprompt, genparams, basic_api_flag):
         loop = asyncio.get_event_loop()
         executor = ThreadPoolExecutor()

         def run_blocking():
-            # Reset finished status before generating
-            handle.bind_set_stream_finished(False)
-
             if basic_api_flag:
                 return generate(
                     prompt=newprompt,
@@ -249,21 +249,21 @@
                     seed=genparams.get('sampler_seed', -1),
                     stop_sequence=genparams.get('stop_sequence', [])
                 )
-
-            return generate(prompt=newprompt,
-                max_context_length=genparams.get('max_context_length', maxctx),
-                max_length=genparams.get('max_length', 50),
-                temperature=genparams.get('temperature', 0.8),
-                top_k=genparams.get('top_k', 120),
-                top_a=genparams.get('top_a', 0.0),
-                top_p=genparams.get('top_p', 0.85),
-                typical_p=genparams.get('typical', 1.0),
-                tfs=genparams.get('tfs', 1.0),
-                rep_pen=genparams.get('rep_pen', 1.1),
-                rep_pen_range=genparams.get('rep_pen_range', 128),
-                seed=genparams.get('sampler_seed', -1),
-                stop_sequence=genparams.get('stop_sequence', [])
-                )
+            else:
+                return generate(prompt=newprompt,
+                    max_context_length=genparams.get('max_context_length', maxctx),
+                    max_length=genparams.get('max_length', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 120),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                    )

         recvtxt = await loop.run_in_executor(executor, run_blocking)
@@ -277,104 +277,139 @@
             print(f"Generate: Error while generating {e}")

-    async def send_sse_event(self, response, event, data):
-        await response.write(f'event: {event}\n'.encode())
-        await response.write(f'data: {data}\n\n'.encode())
+    async def send_sse_event(self, event, data):
+        self.wfile.write(f'event: {event}\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())

-    async def handle_sse_stream(self, request):
-        response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
-        await response.prepare(request)
+
+    async def handle_sse_stream(self):
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+ self.send_header("Cache-Control", "no-cache") + self.send_header("Connection", "keep-alive") + self.end_headers() + + current_token = 0; while not handle.has_finished(): - if not handle.is_locked(): - token = ctypes.string_at(handle.new_token()).decode('utf-8') - event_data = {"token": token} + if current_token < handle.get_stream_count(): + token = handle.new_token(current_token) + + if token is None: # Token isnt ready yet, received nullpointer + continue + + current_token += 1 + + tokenStr = ctypes.string_at(token).decode('utf-8') + event_data = {"token": tokenStr} event_str = json.dumps(event_data) - await self.send_sse_event(response, "message", event_str) + await self.send_sse_event("message", event_str) await asyncio.sleep(0) - await response.write_eof() - await response.force_close() + await self.wfile.close() - async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag): + + async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag): tasks = [] if stream_flag: - tasks.append(self.handle_sse_stream(request,)) + tasks.append(self.handle_sse_stream()) generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag)) tasks.append(generate_task) try: await asyncio.gather(*tasks) + print("done") generate_result = generate_task.result() return generate_result except Exception as e: print(e) - async def handle_get(self, request): + def do_GET(self): global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock - path = request.path.rstrip('/') + self.path = self.path.rstrip('/') + response_body = None - if path in ["", "/?"] or path.startswith(('/?', '?')): - if args.stream and not "streaming=1" in path: - path = path.replace("streaming=0", "") - if path.startswith(('/?', '?')): - path += "&streaming=1" + if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without / + if args.stream and not "streaming=1" in self.path: + self.path = self.path.replace("streaming=0","") + if self.path.startswith(('/?','?')): + self.path += "&streaming=1" else: - path = path + "?streaming=1" + self.path = self.path + "?streaming=1" + self.send_response(302) + self.send_header("Location", self.path) + self.end_headers() + print("Force redirect to streaming mode, as --stream is set.") + return None if self.embedded_kailite is None: - return web.Response(body=f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.".encode()) + response_body = (f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.").encode() else: - return web.Response(body=self.embedded_kailite, content_type='text/html') + response_body = self.embedded_kailite - elif path.endswith(('/api/v1/model', '/api/latest/model')): - return web.json_response({'result': friendlymodelname}) + elif self.path.endswith(('/api/v1/model', '/api/latest/model')): + response_body = (json.dumps({'result': friendlymodelname }).encode()) - elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')): - return web.json_response({"value": maxlen}) + elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')): + response_body = (json.dumps({"value": maxlen}).encode()) - elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')): - return web.json_response({"value": maxctx}) + elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')): + response_body = (json.dumps({"value": maxctx}).encode()) - elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')): - return web.json_response({"value": ""}) + elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')): + response_body = (json.dumps({"value":""}).encode()) - elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')): - return web.json_response({"values": []}) + elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')): + response_body = (json.dumps({"values": []}).encode()) - elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')): - return web.json_response({"result": "1.2.2"}) + elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')): + response_body = (json.dumps({"result":"1.2.2"}).encode()) - elif path.endswith(('/api/extra/version')): - return web.json_response({"result": "KoboldCpp", "version": KcppVersion}) + elif self.path.endswith(('/api/extra/version')): + response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode()) - return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.") + if response_body is None: + self.send_response(404) + self.end_headers() + rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.' 
+            self.wfile.write(rp.encode())
+        else:
+            self.send_response(200)
+            self.send_header('Content-Length', str(len(response_body)))
+            self.end_headers()
+            self.wfile.write(response_body)
+        return

-    async def handle_post(self, request):
+    def do_POST(self):
         global modelbusy
-        body = await request.content.read()
+        content_length = int(self.headers['Content-Length'])
+        body = self.rfile.read(content_length)
         basic_api_flag = False
         kai_api_flag = False
         kai_sse_stream_flag = False
-        path = request.path.rstrip('/')
+        self.path = self.path.rstrip('/')
+
         if modelbusy:
-            return web.json_response(
-                {"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}},
-                status=503,
-            )
+            self.send_response(503)
+            self.end_headers()
+            self.wfile.write(json.dumps({"detail": {
+                "msg": "Server is busy; please try again later.",
+                "type": "service_unavailable",
+            }}).encode())
+            return

-        if path.endswith('/request'):
+        if self.path.endswith('/request'):
             basic_api_flag = True

-        if path.endswith(('/api/v1/generate', '/api/latest/generate')):
+        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True

-        if path.endswith('/api/extra/generate/stream'):
+        if self.path.endswith('/api/extra/generate/stream'):
             kai_api_flag = True
             kai_sse_stream_flag = True

@@ -383,66 +418,94 @@
         try:
             genparams = json.loads(body)
         except ValueError as e:
-            return web.Response(status=503)
+            return self.send_response(503)

         utfprint("\nInput: " + json.dumps(genparams))

         modelbusy = True
+
         if kai_api_flag:
             fullprompt = genparams.get('prompt', "")
         else:
             fullprompt = genparams.get('text', "")
         newprompt = fullprompt

-        gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
+        gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+
+        try:
+            self.wfile.write(json.dumps(gen).encode())
+        except:
+            print("Generate: The response could not be sent; the connection may have been terminated.")

         modelbusy = False
-        if not kai_sse_stream_flag:
-            return web.Response(body=json.dumps(gen).encode())
-        return web.Response()
+        return

-        return web.Response(status=404)
+        self.send_response(404)
+        self.end_headers()

-    async def handle_options(self):
-        return web.Response()
+    def do_OPTIONS(self):
+        self.send_response(200)
+        self.end_headers()

-    async def handle_head(self):
-        return web.Response()
+    def do_HEAD(self):
+        self.send_response(200)
+        self.end_headers()

-    async def start_server(self):
+    def end_headers(self):
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Methods', '*')
+        self.send_header('Access-Control-Allow-Headers', '*')
+        if "/api" in self.path:
+            if self.path.endswith("/stream"):
+                self.send_header('Content-type', 'text/event-stream')
+            self.send_header('Content-type', 'application/json')
+        else:
+            self.send_header('Content-type', 'text/html')
+        return super(ServerRequestHandler, self).end_headers()

-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind((self.addr, self.port))
-        sock.listen(5)

-        self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
-        self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
-        self.app.router.add_route('OPTIONS', '/', self.handle_options)
-        self.app.router.add_route('HEAD', '/', self.handle_head)
+def RunServerMultiThreaded(addr, port, embedded_kailite = None):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind((addr, port))
+    sock.listen(5)

-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.SockSite(runner, sock)
-        await site.start()
+    class Thread(threading.Thread):
+        def __init__(self, i):
+            threading.Thread.__init__(self)
+            self.i = i
+            self.daemon = True
+            self.start()

-        # Keep Alive
+        def run(self):
+            handler = ServerRequestHandler(addr, port, embedded_kailite)
+            with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
+                try:
+                    self.httpd.socket = sock
+                    self.httpd.server_bind = self.server_close = lambda self: None
+                    self.httpd.serve_forever()
+                except (KeyboardInterrupt, SystemExit):
+                    self.httpd.server_close()
+                    sys.exit(0)
+                finally:
+                    self.httpd.server_close()
+                    sys.exit(0)
+
+        def stop(self):
+            self.httpd.server_close()
+
+    numThreads = 6
+    threadArr = []
+    for i in range(numThreads):
+        threadArr.append(Thread(i))
+    while 1:
         try:
-            while True:
-                await asyncio.sleep(3600)
+            time.sleep(10)
         except KeyboardInterrupt:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-        finally:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-
-async def run_server(addr, port, embedded_kailite=None):
-    handler = ServerRequestHandler(addr, port, embedded_kailite)
-    await handler.start_server()
+            for i in range(numThreads):
+                threadArr[i].stop()
+            sys.exit(0)


 def show_gui():
@@ -514,15 +577,14 @@
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
-    frm3 = tk.Frame(root)
-    tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
-    tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
-    tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
-    tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
-    tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
-    tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
-
-    frm3.grid(row=5,column=0,pady=4)
+    frameD = tk.Frame(root)
+    tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
+    tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
+    tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
+    tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
+    tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
+    tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
+    frameD.grid(row=5,column=0,pady=4)

     # Create button, it will change label text
     tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0)
@@ -702,7 +764,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
             print(f"Please connect to custom endpoint at {epurl}")
-    asyncio.run(run_server(args.host, args.port, embedded_kailite))
+    RunServerMultiThreaded(args.host, args.port, embedded_kailite)

 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually
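
Usage sketch (not part of the patch above): the new flow is poll-based. The C++
side appends each decoded token to generated_tokens, and handle_sse_stream()
polls get_stream_count() / new_token(idx) until has_finished() reports true,
forwarding each token as a server-sent event on /api/extra/generate/stream.
The minimal stdlib client below shows how that stream might be consumed; the
host, port (the patch's defaultport of 5001), prompt, and the stream_tokens
helper name are illustrative assumptions, not part of the patch.

    import json
    import http.client

    def stream_tokens(prompt, host="localhost", port=5001):
        # POST a generation request to the SSE endpoint added in this patch.
        conn = http.client.HTTPConnection(host, port)
        payload = json.dumps({"prompt": prompt, "max_length": 50})
        conn.request("POST", "/api/extra/generate/stream", payload,
                     {"Content-Type": "application/json"})
        resp = conn.getresponse()
        # The server writes "event: message" / "data: {...}" pairs; decode the
        # JSON payload of each data line and print the token it carries.
        for raw in resp:
            line = raw.decode("utf-8").strip()
            if line.startswith("data:"):
                token = json.loads(line[len("data:"):])["token"]
                print(token, end="", flush=True)

    if __name__ == "__main__":
        stream_tokens("Once upon a time")

Because the stream is driven by polling the handle rather than by callbacks,
the client sees tokens in index order even if it connects mid-generation; any
trailing non-"data:" lines (such as the final JSON result written by do_POST)
are simply skipped by the filter above.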