From a3f708afce4324e03b36a7915a0a2da80057fb1d Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Thu, 16 Nov 2023 00:58:08 +0800
Subject: [PATCH] added more fields to the openai compatible completions APIs

---
 koboldcpp.py | 83 ++++++++++++++++++++++++++--------------------------
 1 file changed, 41 insertions(+), 42 deletions(-)

diff --git a/koboldcpp.py b/koboldcpp.py
index 7a04cfadd..af3e22edc 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -425,68 +425,64 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     async def generate_text(self, genparams, api_format, stream_flag):
         global friendlymodelname
-        def run_blocking():
+        def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"] = genparams.get('max', 80)
-            elif api_format==3:
+                genparams["max_length"] = genparams.get('max', 100)
+
+            elif api_format==3 or api_format==4:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["max_length"] = genparams.get('max_tokens', 100)
                 genparams["rep_pen"] = scaled_rep_pen
                 # openai allows either a string or a list as a stop sequence
                 if isinstance(genparams.get('stop',[]), list):
                     genparams["stop_sequence"] = genparams.get('stop', [])
                 else:
                     genparams["stop_sequence"] = [genparams.get('stop')]
-            elif api_format==4:
-                # translate openai chat completion messages format into one big string.
-                messages_array = genparams.get('messages', [])
-                adapter_obj = genparams.get('adapter', {})
-                messages_string = ""
-                system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
-                system_message_end = adapter_obj.get("system_end", "")
-                user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
-                user_message_end = adapter_obj.get("user_end", "")
-                assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
-                assistant_message_end = adapter_obj.get("assistant_end", "")
-                for message in messages_array:
-                    if message['role'] == "system":
-                        messages_string += system_message_start
-                    elif message['role'] == "user":
-                        messages_string += user_message_start
-                    elif message['role'] == "assistant":
-                        messages_string += assistant_message_start
+                genparams["sampler_seed"] = genparams.get('seed', -1)
+                genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+                genparams["mirostat"] = genparams.get('mirostat_mode', 0)
 
-                    messages_string += message['content']
+                if api_format==4:
+                    # translate openai chat completion messages format into one big string.
+                    messages_array = genparams.get('messages', [])
+                    adapter_obj = genparams.get('adapter', {})
+                    messages_string = ""
+                    system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
+                    system_message_end = adapter_obj.get("system_end", "")
+                    user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
+                    user_message_end = adapter_obj.get("user_end", "")
+                    assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
+                    assistant_message_end = adapter_obj.get("assistant_end", "")
 
-                    if message['role'] == "system":
-                        messages_string += system_message_end
-                    elif message['role'] == "user":
-                        messages_string += user_message_end
-                    elif message['role'] == "assistant":
-                        messages_string += assistant_message_end
+                    for message in messages_array:
+                        if message['role'] == "system":
+                            messages_string += system_message_start
+                        elif message['role'] == "user":
+                            messages_string += user_message_start
+                        elif message['role'] == "assistant":
+                            messages_string += assistant_message_start
 
-                messages_string += assistant_message_start
+                        messages_string += message['content']
 
-                genparams["prompt"] = messages_string
-                frqp = genparams.get('frequency_penalty', 0.1)
-                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
-                genparams["rep_pen"] = scaled_rep_pen
-                # openai allows either a string or a list as a stop sequence
-                if isinstance(genparams.get('stop',[]), list):
-                    genparams["stop_sequence"] = genparams.get('stop', [])
-                else:
-                    genparams["stop_sequence"] = [genparams.get('stop')]
+                        if message['role'] == "system":
+                            messages_string += system_message_end
+                        elif message['role'] == "user":
+                            messages_string += user_message_end
+                        elif message['role'] == "assistant":
+                            messages_string += assistant_message_end
+
+                    messages_string += assistant_message_start
+                    genparams["prompt"] = messages_string
 
         return generate(
             prompt=genparams.get('prompt', ""),
             memory=genparams.get('memory', ""),
             max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 80),
+            max_length=genparams.get('max_length', 100),
             temperature=genparams.get('temperature', 0.7),
             top_k=genparams.get('top_k', 100),
             top_a=genparams.get('top_a', 0.0),
@@ -578,6 +574,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     if api_format == 4: # if oai chat, set format to expected openai streaming response
                         event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
                         await self.send_oai_sse_event(event_str)
+                    elif api_format == 3: # non chat completions
+                        event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                        await self.send_oai_sse_event(event_str)
                     else:
                         event_str = json.dumps({"token": tokenStr})
                         await self.send_kai_sse_event(event_str)
@@ -817,7 +816,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             bring_terminal_to_foreground()
 
         # Check if streaming chat completions, if so, set stream mode to true
-        if api_format == 4 and "stream" in genparams and genparams["stream"]:
+        if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
            sse_stream_flag = True
 
        gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
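
Two reviewer notes follow; neither is part of the patch itself.

First, a worked example of the chat-message flattening added above. It is a
simplified sketch (system and user turns only, made-up messages) and assumes
the request carries no "adapter" object, so the defaults from the diff apply:

    # Illustration only: what run_blocking() builds for api_format 4.
    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi!"},
    ]
    system_message_start = "\n### Instruction:\n"  # default system_start
    user_message_start = "\n### Instruction:\n"    # default user_start
    assistant_message_start = "\n### Response:\n"  # default assistant_start

    prompt = ""
    for m in messages:
        start = system_message_start if m["role"] == "system" else user_message_start
        prompt += start + m["content"]             # *_end defaults are ""
    prompt += assistant_message_start              # cue the assistant's turn

    assert prompt == ("\n### Instruction:\nYou are helpful."
                      "\n### Instruction:\nHi!"
                      "\n### Response:\n")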
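Second, a minimal client sketch for the new streaming text-completions path
(api_format 3). The port and the /v1/completions route are assumptions not
shown in this diff, and the b"[DONE]" check follows the OpenAI SSE convention
rather than something this patch adds:

    import json
    import requests  # assumed available; any streaming HTTP client works

    payload = {
        "prompt": "Once upon a time",
        "max_tokens": 100,    # mapped to genparams["max_length"]
        "temperature": 0.7,
        "stop": "\n\n",       # a string or a list are both accepted
        "seed": 1234,         # mapped to genparams["sampler_seed"]
        "ignore_eos": False,  # mapped to genparams["use_default_badwordsids"]
        "mirostat_mode": 0,   # mapped to genparams["mirostat"]
        "stream": True,       # enables the new text_completion SSE events
    }

    with requests.post("http://localhost:5001/v1/completions",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue  # skip blank keep-alive lines and non-data fields
            data = line[len(b"data: "):]
            if data == b"[DONE]":
                break
            chunk = json.loads(data)
            # each chunk mirrors the text_completion event_str built above
            print(chunk["choices"][0]["text"], end="", flush=True)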