added more fields to the openai compatible completions APIs

This commit is contained in:
Concedo 2023-11-16 00:58:08 +08:00
parent 914e375602
commit a3f708afce

View file

@ -425,68 +425,64 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
async def generate_text(self, genparams, api_format, stream_flag):
global friendlymodelname
def run_blocking():
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
if api_format==1:
genparams["prompt"] = genparams.get('text', "")
genparams["top_k"] = int(genparams.get('top_k', 120))
genparams["max_length"] = genparams.get('max', 80)
elif api_format==3:
genparams["max_length"] = genparams.get('max', 100)
elif api_format==3 or api_format==4:
frqp = genparams.get('frequency_penalty', 0.1)
scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
genparams["max_length"] = genparams.get('max_tokens', 80)
genparams["max_length"] = genparams.get('max_tokens', 100)
genparams["rep_pen"] = scaled_rep_pen
# openai allows either a string or a list as a stop sequence
if isinstance(genparams.get('stop',[]), list):
genparams["stop_sequence"] = genparams.get('stop', [])
else:
genparams["stop_sequence"] = [genparams.get('stop')]
elif api_format==4:
# translate openai chat completion messages format into one big string.
messages_array = genparams.get('messages', [])
adapter_obj = genparams.get('adapter', {})
messages_string = ""
system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
system_message_end = adapter_obj.get("system_end", "")
user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
for message in messages_array:
if message['role'] == "system":
messages_string += system_message_start
elif message['role'] == "user":
messages_string += user_message_start
elif message['role'] == "assistant":
messages_string += assistant_message_start
genparams["sampler_seed"] = genparams.get('seed', -1)
genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
messages_string += message['content']
if api_format==4:
# translate openai chat completion messages format into one big string.
messages_array = genparams.get('messages', [])
adapter_obj = genparams.get('adapter', {})
messages_string = ""
system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
system_message_end = adapter_obj.get("system_end", "")
user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
if message['role'] == "system":
messages_string += system_message_end
elif message['role'] == "user":
messages_string += user_message_end
elif message['role'] == "assistant":
messages_string += assistant_message_end
for message in messages_array:
if message['role'] == "system":
messages_string += system_message_start
elif message['role'] == "user":
messages_string += user_message_start
elif message['role'] == "assistant":
messages_string += assistant_message_start
messages_string += assistant_message_start
messages_string += message['content']
genparams["prompt"] = messages_string
frqp = genparams.get('frequency_penalty', 0.1)
scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
genparams["max_length"] = genparams.get('max_tokens', 80)
genparams["rep_pen"] = scaled_rep_pen
# openai allows either a string or a list as a stop sequence
if isinstance(genparams.get('stop',[]), list):
genparams["stop_sequence"] = genparams.get('stop', [])
else:
genparams["stop_sequence"] = [genparams.get('stop')]
if message['role'] == "system":
messages_string += system_message_end
elif message['role'] == "user":
messages_string += user_message_end
elif message['role'] == "assistant":
messages_string += assistant_message_end
messages_string += assistant_message_start
genparams["prompt"] = messages_string
return generate(
prompt=genparams.get('prompt', ""),
memory=genparams.get('memory', ""),
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 80),
max_length=genparams.get('max_length', 100),
temperature=genparams.get('temperature', 0.7),
top_k=genparams.get('top_k', 100),
top_a=genparams.get('top_a', 0.0),
@ -578,6 +574,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if api_format == 4: # if oai chat, set format to expected openai streaming response
event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
await self.send_oai_sse_event(event_str)
elif api_format == 3: # non chat completions
event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
await self.send_oai_sse_event(event_str)
else:
event_str = json.dumps({"token": tokenStr})
await self.send_kai_sse_event(event_str)
@ -817,7 +816,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
bring_terminal_to_foreground()
# Check if streaming chat completions, if so, set stream mode to true
if api_format == 4 and "stream" in genparams and genparams["stream"]:
if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
sse_stream_flag = True
gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))