diff --git a/koboldcpp.py b/koboldcpp.py
index eddbc64ae..9f0470ed2 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -392,17 +392,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         pass

     async def generate_text(self, genparams, api_format, stream_flag):
-
+        global friendlymodelname
         def run_blocking():
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"]=genparams.get('max', 50)
+                genparams["max_length"] = genparams.get('max', 80)
             elif api_format==3:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 50)
+                genparams["max_length"] = genparams.get('max_tokens', 80)
                 genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]
+            elif api_format==4:
+                # translate openai chat completion messages format into one big string.
+                messages_array = genparams.get('messages', [])
+                messages_string = ""
+                for message in messages_array:
+                    if message['role'] == "system":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "user":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "assistant":
+                        messages_string+="\n### Response:\n"
+                    messages_string+=message['content']
+                messages_string += "\n### Response:\n"
+                genparams["prompt"] = messages_string
+                frqp = genparams.get('frequency_penalty', 0.1)
+                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
+                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]

             return generate(
                 prompt=genparams.get('prompt', ""),
@@ -442,8 +470,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if api_format==1:
                 res = {"data": {"seqs":[recvtxt]}}
             elif api_format==3:
-                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
+                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
                 "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+            elif api_format==4:
+                res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
+                "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]}
             else:
                 res = {"results": [{"text": recvtxt}]}

@@ -453,19 +484,21 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             print(f"Generate: Error while generating: {e}")


-    async def send_sse_event(self, event, data):
-        self.wfile.write(f'event: {event}\n'.encode())
+    async def send_oai_sse_event(self, data):
+        self.wfile.write(f'data: {data}\r\n\r\n'.encode())
+
+    async def send_kai_sse_event(self, data):
+        self.wfile.write(f'event: message\n'.encode())
         self.wfile.write(f'data: {data}\n\n'.encode())
-
-    async def handle_sse_stream(self):
+    async def handle_sse_stream(self, api_format):
+        global friendlymodelname
         self.send_response(200)
         self.send_header("Cache-Control", "no-cache")
         self.send_header("Connection", "keep-alive")
-        self.end_headers()
+        self.end_headers(force_json=True, sse_stream_flag=True)
         current_token = 0
-        incomplete_token_buffer = bytearray()

         while True:
             streamDone = handle.has_finished() #exit next loop on done
@@ -486,14 +519,20 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     tokenStr += tokenSeg

             if tokenStr!="":
-                event_data = {"token": tokenStr}
-                event_str = json.dumps(event_data)
+                if api_format == 4: # if oai chat, set format to expected openai streaming response
+                    event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                    await self.send_oai_sse_event(event_str)
+                else:
+                    event_str = json.dumps({"token": tokenStr})
+                    await self.send_kai_sse_event(event_str)
                 tokenStr = ""
-                await self.send_sse_event("message", event_str)
+
             else:
                 await asyncio.sleep(0.02) #this should keep things responsive

             if streamDone:
+                if api_format == 4: # if oai chat, send last [DONE] message consistent with openai format
+                    await self.send_oai_sse_event('[DONE]')
                 break

         # flush buffers, sleep a bit to make sure all data sent, and then force close the connection
@@ -506,7 +545,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         tasks = []

         if stream_flag:
-            tasks.append(self.handle_sse_stream())
+            tasks.append(self.handle_sse_stream(api_format))

         generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
         tasks.append(generate_task)
@@ -570,8 +609,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
             response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())

-        elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
-            response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+        elif self.path.endswith('/v1/models'):
+            response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
             force_json = True

         elif self.path=="/api":
@@ -604,7 +643,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         body = self.rfile.read(content_length)
         self.path = self.path.rstrip('/')
         force_json = False
-
         if self.path.endswith(('/api/extra/tokencount')):
             try:
                 genparams = json.loads(body)
@@ -667,9 +705,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0

         try:
-            kai_sse_stream_flag = False
+            sse_stream_flag = False

-            api_format = 0 #1=basic,2=kai,3=oai
+            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat

             if self.path.endswith('/request'):
                 api_format = 1
@@ -679,12 +717,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

             if self.path.endswith('/api/extra/generate/stream'):
                 api_format = 2
-                kai_sse_stream_flag = True
+                sse_stream_flag = True

-            if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
+            if self.path.endswith('/v1/completions'):
                 api_format = 3
                 force_json = True

+            if self.path.endswith('/v1/chat/completions'):
+                api_format = 4
+                force_json = True
+
             if api_format>0:
                 genparams = None
                 try:
@@ -699,11 +741,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if args.foreground:
                     bring_terminal_to_foreground()

-                gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
+                # Check if streaming chat completions, if so, set stream mode to true
+                if api_format == 4 and "stream" in genparams and genparams["stream"]:
+                    sse_stream_flag = True
+
+                gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))

                 try:
                     # Headers are already sent when streaming
-                    if not kai_sse_stream_flag:
+                    if not sse_stream_flag:
                         self.send_response(200)
                         self.end_headers(force_json=force_json)
                         self.wfile.write(json.dumps(gen).encode())
@@ -726,12 +772,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_response(200)
         self.end_headers()

-    def end_headers(self, force_json=False):
+    def end_headers(self, force_json=False, sse_stream_flag=False):
         self.send_header('Access-Control-Allow-Origin', '*')
         self.send_header('Access-Control-Allow-Methods', '*')
         self.send_header('Access-Control-Allow-Headers', '*')
         if ("/api" in self.path and self.path!="/api") or force_json:
-            if self.path.endswith("/stream"):
+            if sse_stream_flag:
                 self.send_header('Content-type', 'text/event-stream')
             self.send_header('Content-type', 'application/json')
         else:
@@ -928,10 +974,12 @@ def show_new_gui():
         x, y = root.winfo_pointerxy()
         tooltip.wm_geometry(f"+{x + 10}+{y + 10}")
         tooltip.deiconify()
+
     def hide_tooltip(event):
         if hasattr(show_tooltip, "_tooltip"):
             tooltip = show_tooltip._tooltip
             tooltip.withdraw()
+
     def setup_backend_tooltip(parent):
         num_backends_built = makelabel(parent, str(len(runopts)) + "/6", 5, 2)
         num_backends_built.grid(row=1, column=2, padx=0, pady=0)
@@ -1106,7 +1154,6 @@ def show_new_gui():
     for idx, name, in enumerate(token_boxes):
         makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)

-
     # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=2)

@@ -1799,4 +1846,4 @@ if __name__ == '__main__':
     parser.add_argument("--multiuser", help="Runs in multiuser mode, which queues incoming requests instead of blocking them.", action='store_true')
     parser.add_argument("--foreground", help="Windows only. Sends the terminal to the foreground every time a new prompt is generated. This helps avoid some idle slowdown issues.", action='store_true')

-    main(parser.parse_args(),start_server=True)
\ No newline at end of file
+    main(parser.parse_args(),start_server=True)
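Reviewer note (not part of the patch): a minimal client sketch for exercising the new /v1/chat/completions route and its SSE streaming path. The host/port (localhost:5001) and the payload values are assumptions for illustration; the field names mirror what the handler above reads (messages, max_tokens, stream).

# Illustrative only -- assumes a local koboldcpp server on the default port 5001.
import json
import requests

BASE_URL = "http://localhost:5001"  # assumed default host/port

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 80,   # mapped to max_length by generate_text above
    "stream": True,     # triggers the SSE path (api_format == 4)
}

with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # final sentinel emitted by send_oai_sse_event
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)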