added more fields to the openai compatible completions APIs
parent 914e375602
commit a3f708afce

1 changed file with 41 additions and 42 deletions:

koboldcpp.py (83 changes)
@@ -425,68 +425,64 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     async def generate_text(self, genparams, api_format, stream_flag):
         global friendlymodelname
-        def run_blocking():
+        def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"] = genparams.get('max', 80)
+                genparams["max_length"] = genparams.get('max', 100)

-            elif api_format==3:
+            elif api_format==3 or api_format==4:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["max_length"] = genparams.get('max_tokens', 100)
                 genparams["rep_pen"] = scaled_rep_pen
                 # openai allows either a string or a list as a stop sequence
                 if isinstance(genparams.get('stop',[]), list):
                     genparams["stop_sequence"] = genparams.get('stop', [])
                 else:
                     genparams["stop_sequence"] = [genparams.get('stop')]
+                genparams["sampler_seed"] = genparams.get('seed', -1)
+                genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+                genparams["mirostat"] = genparams.get('mirostat_mode', 0)

-            elif api_format==4:
-                # translate openai chat completion messages format into one big string.
-                messages_array = genparams.get('messages', [])
-                adapter_obj = genparams.get('adapter', {})
-                messages_string = ""
-                system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
-                system_message_end = adapter_obj.get("system_end", "")
-                user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
-                user_message_end = adapter_obj.get("user_end", "")
-                assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
-                assistant_message_end = adapter_obj.get("assistant_end", "")
+                if api_format==4:
+                    # translate openai chat completion messages format into one big string.
+                    messages_array = genparams.get('messages', [])
+                    adapter_obj = genparams.get('adapter', {})
+                    messages_string = ""
+                    system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
+                    system_message_end = adapter_obj.get("system_end", "")
+                    user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
+                    user_message_end = adapter_obj.get("user_end", "")
+                    assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
+                    assistant_message_end = adapter_obj.get("assistant_end", "")

-                for message in messages_array:
-                    if message['role'] == "system":
-                        messages_string += system_message_start
-                    elif message['role'] == "user":
-                        messages_string += user_message_start
-                    elif message['role'] == "assistant":
-                        messages_string += assistant_message_start
+                    for message in messages_array:
+                        if message['role'] == "system":
+                            messages_string += system_message_start
+                        elif message['role'] == "user":
+                            messages_string += user_message_start
+                        elif message['role'] == "assistant":
+                            messages_string += assistant_message_start

-                    messages_string += message['content']
+                        messages_string += message['content']

-                    if message['role'] == "system":
-                        messages_string += system_message_end
-                    elif message['role'] == "user":
-                        messages_string += user_message_end
-                    elif message['role'] == "assistant":
-                        messages_string += assistant_message_end
+                        if message['role'] == "system":
+                            messages_string += system_message_end
+                        elif message['role'] == "user":
+                            messages_string += user_message_end
+                        elif message['role'] == "assistant":
+                            messages_string += assistant_message_end

-                messages_string += assistant_message_start
-                genparams["prompt"] = messages_string
-                frqp = genparams.get('frequency_penalty', 0.1)
-                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
-                genparams["rep_pen"] = scaled_rep_pen
-                # openai allows either a string or a list as a stop sequence
-                if isinstance(genparams.get('stop',[]), list):
-                    genparams["stop_sequence"] = genparams.get('stop', [])
-                else:
-                    genparams["stop_sequence"] = [genparams.get('stop')]
+                    messages_string += assistant_message_start
+                    genparams["prompt"] = messages_string

             return generate(
                 prompt=genparams.get('prompt', ""),
                 memory=genparams.get('memory', ""),
                 max_context_length=genparams.get('max_context_length', maxctx),
-                max_length=genparams.get('max_length', 80),
+                max_length=genparams.get('max_length', 100),
                 temperature=genparams.get('temperature', 0.7),
                 top_k=genparams.get('top_k', 100),
                 top_a=genparams.get('top_a', 0.0),
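A minimal standalone sketch (not koboldcpp's own code) of the OpenAI-to-KoboldAI sampler translation the merged `elif api_format==3 or api_format==4:` branch above performs. The function name and request-dict shape are illustrative assumptions; the field names and defaults mirror the diff.

    def translate_oai_params(req):
        out = dict(req)
        # KoboldAI uses a single multiplicative rep_pen centred on 1.0:
        # presence_penalty wins, falling back to frequency_penalty, shifted by +1.
        frqp = req.get('frequency_penalty', 0.1)
        out["rep_pen"] = req.get('presence_penalty', frqp) + 1
        out["max_length"] = req.get('max_tokens', 100)
        # OpenAI allows either a string or a list as a stop sequence.
        stop = req.get('stop', [])
        out["stop_sequence"] = stop if isinstance(stop, list) else [stop]
        # The fields this commit adds:
        out["sampler_seed"] = req.get('seed', -1)
        out["use_default_badwordsids"] = req.get('ignore_eos', False)
        out["mirostat"] = req.get('mirostat_mode', 0)
        return out

    # translate_oai_params({"max_tokens": 64, "stop": "\n", "seed": 42})
    # -> stop_sequence == ["\n"], sampler_seed == 42, rep_pen == 1.1

The chat branch flattens the OpenAI messages array into one prompt string. Here is the same logic as a standalone helper, to make the output shape visible; the adapter keys and default tags mirror the diff, while the helper name and signature are assumptions.

    def flatten_chat(messages, adapter=None):
        adapter = adapter or {}
        starts = {"system": adapter.get("system_start", "\n### Instruction:\n"),
                  "user": adapter.get("user_start", "\n### Instruction:\n"),
                  "assistant": adapter.get("assistant_start", "\n### Response:\n")}
        ends = {"system": adapter.get("system_end", ""),
                "user": adapter.get("user_end", ""),
                "assistant": adapter.get("assistant_end", "")}
        prompt = ""
        for m in messages:
            prompt += starts.get(m['role'], "") + m['content'] + ends.get(m['role'], "")
        # Close with the assistant start tag so generation continues as the assistant.
        prompt += starts["assistant"]
        return prompt

    # flatten_chat([{"role": "user", "content": "Hi"}])
    # -> "\n### Instruction:\nHi\n### Response:\n"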
@@ -578,6 +574,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     if api_format == 4: # if oai chat, set format to expected openai streaming response
                         event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
                         await self.send_oai_sse_event(event_str)
+                    elif api_format == 3: # non chat completions
+                        event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                        await self.send_oai_sse_event(event_str)
                     else:
                         event_str = json.dumps({"token": tokenStr})
                         await self.send_kai_sse_event(event_str)
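For comparison, the two OAI-style streaming payload bodies built above, as a runnable sketch. This only illustrates the JSON event bodies; the actual SSE framing lives in send_oai_sse_event, which this hunk does not show.

    import json

    token, model = "Hello", "my-model"  # placeholder values
    chat_chunk = {"id": "koboldcpp", "object": "chat.completion.chunk", "created": 1,
                  "model": model, "choices": [{"index": 0, "finish_reason": "length",
                  "delta": {"role": "assistant", "content": token}}]}
    text_chunk = {"id": "koboldcpp", "object": "text_completion", "created": 1,
                  "model": model, "choices": [{"index": 0, "finish_reason": "length",
                  "text": token}]}
    print(json.dumps(chat_chunk))  # chat clients read choices[0].delta.content
    print(json.dumps(text_chunk))  # completion clients read choices[0].text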
@@ -817,7 +816,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     bring_terminal_to_foreground()

                 # Check if streaming chat completions, if so, set stream mode to true
-                if api_format == 4 and "stream" in genparams and genparams["stream"]:
+                if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
                     sse_stream_flag = True

                 gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
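A hypothetical client-side check of the widened condition: after this change, "stream": true triggers SSE for plain completions (api_format 3) as well as chat completions. The endpoint path, port, "data: " framing, and "[DONE]" sentinel are assumptions based on common OpenAI-compatible servers, not taken from this diff.

    import json, urllib.request

    body = json.dumps({"prompt": "Hello", "max_tokens": 16, "stream": True}).encode()
    req = urllib.request.Request("http://localhost:5001/v1/completions", data=body,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        for raw in resp:  # read the response line by line as it streams
            line = raw.decode("utf-8").strip()
            if line.startswith("data: ") and "[DONE]" not in line:
                print(json.loads(line[len("data: "):]))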