From 96e9539f05ee61386667b5eea1d74a9a98b94cd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20=C5=A0t=C3=A1gl?= Date: Mon, 9 Oct 2023 17:24:48 +0200 Subject: [PATCH] OpenAI compat API adapter (#466) * feat: oai-adapter * simplify optional adapter for instruct start and end tags --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com> --- koboldcpp.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 0b0b9f31f..1957fd09b 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -412,16 +412,34 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): elif api_format==4: # translate openai chat completion messages format into one big string. messages_array = genparams.get('messages', []) + adapter_obj = genparams.get('adapter', {}) messages_string = "" + system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n") + system_message_end = adapter_obj.get("system_end", "") + user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n") + user_message_end = adapter_obj.get("user_end", "") + assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n") + assistant_message_end = adapter_obj.get("assistant_end", "") + for message in messages_array: if message['role'] == "system": - messages_string+="\n### Instruction:\n" + messages_string += system_message_start elif message['role'] == "user": - messages_string+="\n### Instruction:\n" + messages_string += user_message_start elif message['role'] == "assistant": - messages_string+="\n### Response:\n" - messages_string+=message['content'] - messages_string += "\n### Response:\n" + messages_string += assistant_message_start + + messages_string += message['content'] + + if message['role'] == "system": + messages_string += system_message_end + elif message['role'] == "user": + messages_string += user_message_end + elif message['role'] == "assistant": + messages_string += assistant_message_end + + messages_string += assistant_message_start + genparams["prompt"] = messages_string frqp = genparams.get('frequency_penalty', 0.1) scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1