diff --git a/expose.cpp b/expose.cpp
index 6388bed88..6e4f656e1 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -23,6 +23,10 @@
 
 std::string executable_path = "";
 std::string lora_filename = "";
+
+bool generation_finished;
+std::vector<std::string> generated_tokens;
+
 extern "C"
 {
 
@@ -207,4 +211,18 @@ extern "C"
     {
         return gpttype_generate(inputs, output);
     }
+
+    const char* new_token(int idx) {
+        if (generated_tokens.size() <= idx || idx < 0) return nullptr;
+
+        return generated_tokens[idx].c_str();
+    }
+
+    int get_stream_count() {
+        return generated_tokens.size();
+    }
+
+    bool has_finished() {
+        return generation_finished;
+    }
 }
diff --git a/expose.h b/expose.h
index cb475f141..bb9c5920b 100644
--- a/expose.h
+++ b/expose.h
@@ -18,6 +18,7 @@ struct load_model_inputs
     const int clblast_info = 0;
     const int blasbatchsize = 512;
     const bool debugmode;
+    const bool stream_sse;
     const int forceversion = 0;
     const int gpulayers = 0;
 };
@@ -48,3 +49,6 @@ struct generation_outputs
 
 extern std::string executable_path;
 extern std::string lora_filename;
+
+extern std::vector<std::string> generated_tokens;
+extern bool generation_finished;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index a9ece44f2..3c9862980 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -63,6 +63,7 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
+static bool stream_sse = true;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
 static std::vector<gpt_vocab::id> current_context_tokens;
@@ -735,6 +736,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_batch = n_batch;
     params.n_threads = n_threads;
 
+    generation_finished = false; // Set current generation status
+    generated_tokens.clear();    // New generation, new tokens
+
     if (params.repeat_last_n < 1)
     {
         params.repeat_last_n = 1;
@@ -1038,6 +1042,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
+            generation_finished = true;
             return output;
         }
     }
@@ -1147,7 +1152,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
         for (auto id : embd)
         {
-            concat_output += FileFormatTokenizeID(id,file_format);
+            std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
+
+            if (stream_sse)
+            {
+                generated_tokens.push_back(tokenizedstr);
+            }
+            concat_output += tokenizedstr;
         }
 
         if (startedsampling)
@@ -1214,6 +1225,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
+    generation_finished = true;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 
     return output;
diff --git a/koboldcpp.py b/koboldcpp.py
index cf8847bc8..981513987 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -5,7 +5,8 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time
+import json, sys, http.server, time, asyncio, socket, threading
+from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
 
@@ -134,6 +135,10 @@ def init_library():
     handle.load_model.restype = ctypes.c_bool
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
+    handle.new_token.restype = ctypes.c_char_p
+    handle.new_token.argtypes = [ctypes.c_int]
+    handle.get_stream_count.restype = ctypes.c_int
+    handle.has_finished.restype = ctypes.c_bool
 
 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -183,7 +188,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     else:
         inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
     inputs.seed = seed
-    for n in range(0,stop_token_max):
+    for n in range(stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
@@ -224,8 +229,106 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     def __call__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    async def generate_text(self, newprompt, genparams, basic_api_flag):
+        loop = asyncio.get_event_loop()
+        executor = ThreadPoolExecutor()
+
+        def run_blocking():
+            if basic_api_flag:
+                return generate(
+                    prompt=newprompt,
+                    max_length=genparams.get('max', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=int(genparams.get('top_k', 120)),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                )
+            else:
+                return generate(prompt=newprompt,
+                    max_context_length=genparams.get('max_context_length', maxctx),
+                    max_length=genparams.get('max_length', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 120),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                )
+
+        recvtxt = await loop.run_in_executor(executor, run_blocking)
+
+        utfprint("\nOutput: " + recvtxt)
+
+        res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
+
+        try:
+            return res
+        except Exception as e:
+            print(f"Generate: Error while generating: {e}")
+
+
+    async def send_sse_event(self, event, data):
+        self.wfile.write(f'event: {event}\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())
+
+
+    async def handle_sse_stream(self):
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Connection", "keep-alive")
+        self.end_headers()
+
+        current_token = 0
+
+        while not handle.has_finished():
+            if current_token < handle.get_stream_count():
+                token = handle.new_token(current_token)
+
+                if token is None: # Token isn't ready yet, we received a null pointer
+                    continue
+
+                current_token += 1
+
+                tokenStr = ctypes.string_at(token).decode('utf-8')
+                event_data = {"token": tokenStr}
+                event_str = json.dumps(event_data)
+                await self.send_sse_event("message", event_str)
+
+            await asyncio.sleep(0)
+
+        # Implement connection closing here
+
+
+    async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
+        tasks = []
+
+        if stream_flag:
+            tasks.append(self.handle_sse_stream())
+
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
+        tasks.append(generate_task)
+
+        try:
+            await asyncio.gather(*tasks)
+            generate_result = generate_task.result()
+            return generate_result
+        except Exception as e:
+            print(e)
+
+
     def do_GET(self):
-        global maxctx, maxlen, friendlymodelname, KcppVersion
+        global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
         self.path = self.path.rstrip('/')
         response_body = None
@@ -286,8 +389,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         body = self.rfile.read(content_length)
         basic_api_flag = False
         kai_api_flag = False
+        kai_sse_stream_flag = False
         self.path = self.path.rstrip('/')
+
 
         if modelbusy:
             self.send_response(503)
             self.end_headers()
@@ -303,72 +408,44 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True
 
+        if self.path.endswith('/api/extra/generate/stream'):
+            kai_api_flag = True
+            kai_sse_stream_flag = True
+
         if basic_api_flag or kai_api_flag:
             genparams = None
             try:
                 genparams = json.loads(body)
             except ValueError as e:
-                self.send_response(503)
-                self.end_headers()
-                return
+                return self.send_response(503)
+
             utfprint("\nInput: " + json.dumps(genparams))
             modelbusy = True
+
             if kai_api_flag:
                 fullprompt = genparams.get('prompt', "")
             else:
                 fullprompt = genparams.get('text', "")
             newprompt = fullprompt
-            recvtxt = ""
-            res = {}
-            if kai_api_flag:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_context_length=genparams.get('max_context_length', maxctx),
-                    max_length=genparams.get('max_length', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"results": [{"text": recvtxt}]}
-            else:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_length=genparams.get('max', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"data": {"seqs":[recvtxt]}}
+            gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
 
             try:
                 self.send_response(200)
                 self.end_headers()
-                self.wfile.write(json.dumps(res).encode())
+                self.wfile.write(json.dumps(gen).encode())
             except:
                 print("Generate: The response could not be sent, maybe connection was terminated?")
+
             modelbusy = False
+
             return
+
         self.send_response(404)
         self.end_headers()
 
+
     def do_OPTIONS(self):
         self.send_response(200)
         self.end_headers()
@@ -382,10 +459,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_header('Access-Control-Allow-Methods', '*')
         self.send_header('Access-Control-Allow-Headers', '*')
         if "/api" in self.path:
+            if self.path.endswith("/stream"):
+                self.send_header('Content-type', 'text/event-stream')
             self.send_header('Content-type', 'application/json')
         else:
             self.send_header('Content-type', 'text/html')
-
         return super(ServerRequestHandler, self).end_headers()
@@ -500,7 +578,6 @@ def show_gui():
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
-
    frameD = tk.Frame(root)
     tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
     tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
@@ -688,7 +765,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
             print(f"Please connect to custom endpoint at {epurl}")
-    RunServerMultiThreaded(args.host, args.port, embedded_kailite)
+    asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
 
 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually
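
Example client (not part of the patch): a minimal sketch of how the new /api/extra/generate/stream endpoint could be consumed from Python. It assumes the server is reachable at http://localhost:5001 (substitute your actual --host/--port), that the third-party `requests` package is installed, and that only the `data:` lines produced by send_sse_event are of interest; the file name, URL and payload values are illustrative.

# sse_stream_client.py -- hypothetical example, not part of this diff.
# Streams tokens from the /api/extra/generate/stream endpoint and prints
# them as they arrive. Assumes a locally running server and `requests`.
import json
import requests

URL = "http://localhost:5001/api/extra/generate/stream"  # adjust host/port to your setup

payload = {
    "prompt": "Once upon a time,",  # illustrative values; handler reads 'prompt' and 'max_length'
    "max_length": 50,
    "temperature": 0.8,
}

with requests.post(URL, json=payload, stream=True) as resp:
    # The handler writes "event: message" followed by "data: {...}" per token.
    for raw in resp.iter_lines():
        if not raw:
            continue  # blank separators between SSE events
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue  # skip "event:" lines and any trailing non-SSE output
        event = json.loads(line[len("data:"):].strip())
        print(event["token"], end="", flush=True)
print()

Each `data:` payload carries a single {"token": ...} object, mirroring what handle_sse_stream emits from the generated_tokens buffer; the JSON body that do_POST writes after generation completes does not start with `data:` and is simply ignored by this reader.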