From f9f4cdf3c08f1cf0d2272a04821188b1eb1df03e Mon Sep 17 00:00:00 2001
From: teddybear082 <87204721+teddybear082@users.noreply.github.com>
Date: Thu, 5 Oct 2023 08:13:10 -0400
Subject: [PATCH] Implement basic chat/completions openai endpoint (#461)

* Implement basic chat/completions openai endpoint

-Basic support for the openai chat/completions endpoint documented at:
https://platform.openai.com/docs/api-reference/chat/create

-Tested with example code from openai for chat/completions, and for chat/completions with the stream=True parameter, found here:
https://cookbook.openai.com/examples/how_to_stream_completions

-Tested with Mantella, the Skyrim mod that turns all the NPCs into AI-chattable characters, which uses openai's acreate / async completions method:
https://github.com/art-from-the-machine/Mantella/blob/main/src/output_manager.py

-Tested default koboldcpp api behavior with the streaming and non-streaming generate endpoints and with the GUI running, and it seems to be fine.

-Still TODO / evaluate before merging:
(1) implement the rest of the openai chat/completion parameters to the extent possible, mapping them to koboldcpp parameters
(2) determine if there is a way to use kobold's prompt formats for certain models when translating the openai messages format into a prompt string (not sure if this is possible or where these live in the code)
(3) have chat/completions responses include the actual local model the user is running instead of just "koboldcpp" (not sure if this is possible)

Note: I am a python noob, so if there is a more elegant way of doing this, at minimum hopefully I have done some of the grunt work for you to implement on your own.

* Fix typographical error on deleted streaming argument

-Mistakenly left code relating to the streaming argument from the main branch in experimental.

* add additional openai chat completions parameters

-support the stop parameter, mapped to the koboldai stop_sequence parameter
-make the default max_length / max_tokens parameter consistent with the default 80 token length in the generate function
-add support for providing the name of the local model in openai responses

* Revert "add additional openai chat completions parameters"

This reverts commit 443a6f7ff6346f41c78b0a6ff59c063999542327.
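As a quick client-side illustration of the endpoint described above (not part of the patch): a minimal sketch, assuming koboldcpp is listening on its usual default port 5001 and using the pre-1.0 openai Python package that the linked cookbook examples target; the model name is informational only, since koboldcpp serves whatever model it has loaded.

import openai

# Point the 2023-era openai client at the local koboldcpp server instead of api.openai.com.
openai.api_base = "http://localhost:5001/v1"
openai.api_key = "sk-ignored"  # koboldcpp does not validate the key, but the client requires one

completion = openai.ChatCompletion.create(
    model="koboldcpp",  # informational only; the locally loaded model responds
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in five words."},
    ],
    max_tokens=80,
)
print(completion["choices"][0]["message"]["content"])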
* add additional openai chat completions parameters

-support the stop parameter, mapped to the koboldai stop_sequence parameter
-make the default max_length / max_tokens parameter consistent with the default 80 token length in the generate function
-add support for providing the name of the local model in openai responses

* add \n after formatting prompts from the openai format to conform with the alpaca standard used as the default in lite.koboldai.net (see the prompt-translation sketch below)

* tidy up and simplify code, do not set globals for streaming

* oai endpoints must start with v1

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
---
 koboldcpp.py | 99 ++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 73 insertions(+), 26 deletions(-)
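To make the alpaca-style mapping mentioned above concrete, here is a standalone sketch of the translation the api_format==4 branch in the diff performs (the helper name is invented for illustration): system and user turns both become "### Instruction:" blocks, assistant turns become "### Response:" blocks, and a trailing "### Response:" header cues the model to reply.

def messages_to_alpaca_prompt(messages):
    # Mirrors the patch: concatenate openai chat turns into one Alpaca-style prompt string.
    prompt = ""
    for message in messages:
        if message["role"] in ("system", "user"):
            prompt += "\n### Instruction:\n"
        elif message["role"] == "assistant":
            prompt += "\n### Response:\n"
        prompt += message["content"]
    prompt += "\n### Response:\n"
    return prompt

print(messages_to_alpaca_prompt([
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one moon of Mars."},
]))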
diff --git a/koboldcpp.py b/koboldcpp.py
index f8f6f2634..80547f724 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -391,17 +391,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         pass
 
     async def generate_text(self, genparams, api_format, stream_flag):
-
+        global friendlymodelname
         def run_blocking():
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"]=genparams.get('max', 50)
+                genparams["max_length"] = genparams.get('max', 80)
 
             elif api_format==3:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 50)
+                genparams["max_length"] = genparams.get('max_tokens', 80)
                 genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]
+
+            elif api_format==4:
+                # translate openai chat completion messages format into one big string
+                messages_array = genparams.get('messages', [])
+                messages_string = ""
+                for message in messages_array:
+                    if message['role'] == "system":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "user":
+                        messages_string+="\n### Instruction:\n"
+                    elif message['role'] == "assistant":
+                        messages_string+="\n### Response:\n"
+                    messages_string+=message['content']
+                messages_string += "\n### Response:\n"
+                genparams["prompt"] = messages_string
+                frqp = genparams.get('frequency_penalty', 0.1)
+                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
+                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["rep_pen"] = scaled_rep_pen
+                # openai allows either a string or a list as a stop sequence
+                if isinstance(genparams.get('stop',[]), list):
+                    genparams["stop_sequence"] = genparams.get('stop', [])
+                else:
+                    genparams["stop_sequence"] = [genparams.get('stop')]
 
             return generate(
                 prompt=genparams.get('prompt', ""),
@@ -441,8 +469,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if api_format==1:
                 res = {"data": {"seqs":[recvtxt]}}
             elif api_format==3:
-                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
+                res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
                 "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+            elif api_format==4:
+                res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
+                "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]}
             else:
                 res = {"results": [{"text": recvtxt}]}
 
@@ -452,19 +483,21 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             print(f"Generate: Error while generating: {e}")
 
-    async def send_sse_event(self, event, data):
-        self.wfile.write(f'event: {event}\n'.encode())
+    async def send_oai_sse_event(self, data):
+        self.wfile.write(f'data: {data}\r\n\r\n'.encode())
+
+    async def send_kai_sse_event(self, data):
+        self.wfile.write(f'event: message\n'.encode())
         self.wfile.write(f'data: {data}\n\n'.encode())
-
-    async def handle_sse_stream(self):
+    async def handle_sse_stream(self, api_format):
+        global friendlymodelname
         self.send_response(200)
         self.send_header("Cache-Control", "no-cache")
         self.send_header("Connection", "keep-alive")
-        self.end_headers()
+        self.end_headers(force_json=True, sse_stream_flag=True)
 
         current_token = 0
-        incomplete_token_buffer = bytearray()
 
         while True:
             streamDone = handle.has_finished() #exit next loop on done
@@ -485,14 +518,20 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                         tokenStr += tokenSeg
 
             if tokenStr!="":
-                event_data = {"token": tokenStr}
-                event_str = json.dumps(event_data)
+                if api_format == 4:  # if oai chat, set format to expected openai streaming response
+                    event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                    await self.send_oai_sse_event(event_str)
+                else:
+                    event_str = json.dumps({"token": tokenStr})
+                    await self.send_kai_sse_event(event_str)
                 tokenStr = ""
-                await self.send_sse_event("message", event_str)
+
             else:
                 await asyncio.sleep(0.02) #this should keep things responsive
 
             if streamDone:
+                if api_format == 4:  # if oai chat, send last [DONE] message consistent with openai format
+                    await self.send_oai_sse_event('[DONE]')
                 break
 
         # flush buffers, sleep a bit to make sure all data sent, and then force close the connection
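When stream=True, the handler above emits OpenAI-style chat.completion.chunk events terminated by a final data: [DONE] line. A rough reader for that stream using the requests library (URL and port are assumptions; any SSE-capable client should work):

import json
import requests

url = "http://localhost:5001/v1/chat/completions"  # assumed local koboldcpp address
payload = {
    "model": "koboldcpp",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 80,
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # final sentinel sent by the patched server
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)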
@@ -505,7 +544,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         tasks = []
 
         if stream_flag:
-            tasks.append(self.handle_sse_stream())
+            tasks.append(self.handle_sse_stream(api_format))
 
         generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
         tasks.append(generate_task)
@@ -569,8 +608,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
                 response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
 
-        elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
-            response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
+        elif self.path.endswith('/v1/models'):
+            response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
             force_json = True
 
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
@@ -595,7 +634,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         body = self.rfile.read(content_length)
         self.path = self.path.rstrip('/')
         force_json = False
-
         if self.path.endswith(('/api/extra/tokencount')):
             try:
                 genparams = json.loads(body)
@@ -658,9 +696,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0
 
         try:
-            kai_sse_stream_flag = False
+            sse_stream_flag = False
 
-            api_format = 0 #1=basic,2=kai,3=oai
+            api_format = 0 #1=basic,2=kai,3=oai,4=oai-chat
 
             if self.path.endswith('/request'):
                 api_format = 1
@@ -670,12 +708,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
             if self.path.endswith('/api/extra/generate/stream'):
                 api_format = 2
-                kai_sse_stream_flag = True
+                sse_stream_flag = True
 
-            if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
+            if self.path.endswith('/v1/completions'):
                 api_format = 3
                 force_json = True
 
+            if self.path.endswith('/v1/chat/completions'):
+                api_format = 4
+                force_json = True
+
             if api_format>0:
                 genparams = None
                 try:
@@ -690,11 +732,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if args.foreground:
                     bring_terminal_to_foreground()
 
-                gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
+                # Check if streaming chat completions, if so, set stream mode to true
+                if api_format == 4 and "stream" in genparams and genparams["stream"]:
+                    sse_stream_flag = True
+
+                gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
 
                 try:
                     # Headers are already sent when streaming
-                    if not kai_sse_stream_flag:
+                    if not sse_stream_flag:
                         self.send_response(200)
                         self.end_headers(force_json=force_json)
                         self.wfile.write(json.dumps(gen).encode())
@@ -717,12 +763,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_response(200)
         self.end_headers()
 
-    def end_headers(self, force_json=False):
+    def end_headers(self, force_json=False, sse_stream_flag=False):
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        if "/api" in self.path or force_json:
-            if self.path.endswith("/stream"):
+            if sse_stream_flag:
                 self.send_header('Content-type', 'text/event-stream')
             self.send_header('Content-type', 'application/json')
         else:
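As the routing changes above show, the openai-compatible endpoints now live only under /v1/ (the bare /completions and /models aliases are gone). A hedged sketch of a plain non-streaming request exercising the newly supported parameters; the URL, port, and values are assumptions for illustration:

import requests

resp = requests.post(
    "http://localhost:5001/v1/chat/completions",  # note the required /v1/ prefix
    json={
        "model": "koboldcpp",
        "messages": [{"role": "user", "content": "Count from one to ten."}],
        "max_tokens": 80,
        "stop": "five",           # a bare string is wrapped into a one-element stop_sequence
        "presence_penalty": 0.1,  # mapped to rep_pen = presence_penalty + 1
    },
)
body = resp.json()
print(body["model"])                              # friendly name of the loaded model
print(body["choices"][0]["message"]["content"])   # assistant reply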
@@ -919,10 +965,12 @@ def show_new_gui():
         x, y = root.winfo_pointerxy()
         tooltip.wm_geometry(f"+{x + 10}+{y + 10}")
         tooltip.deiconify()
+
     def hide_tooltip(event):
         if hasattr(show_tooltip, "_tooltip"):
             tooltip = show_tooltip._tooltip
             tooltip.withdraw()
+
     def setup_backend_tooltip(parent):
         num_backends_built = makelabel(parent, str(len(runopts)) + "/6", 5, 2)
         num_backends_built.grid(row=1, column=2, padx=0, pady=0)
@@ -1097,7 +1145,6 @@ def show_new_gui():
     for idx, name, in enumerate(token_boxes):
         makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)
 
-    # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=2)