added more fields to the openai compatible completions APIs

Commit: a3f708afce
Parent: 914e375602
Author: Concedo
Date:   2023-11-16 00:58:08 +08:00


@@ -425,22 +425,28 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     async def generate_text(self, genparams, api_format, stream_flag):
         global friendlymodelname
-        def run_blocking():
+        def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
                 genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"] = genparams.get('max', 80)
+                genparams["max_length"] = genparams.get('max', 100)

-            elif api_format==3:
+            elif api_format==3 or api_format==4:
                 frqp = genparams.get('frequency_penalty', 0.1)
                 scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
+                genparams["max_length"] = genparams.get('max_tokens', 100)
                 genparams["rep_pen"] = scaled_rep_pen
                 # openai allows either a string or a list as a stop sequence
                 if isinstance(genparams.get('stop',[]), list):
                     genparams["stop_sequence"] = genparams.get('stop', [])
                 else:
                     genparams["stop_sequence"] = [genparams.get('stop')]
+                genparams["sampler_seed"] = genparams.get('seed', -1)
+                genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+                genparams["mirostat"] = genparams.get('mirostat_mode', 0)

-            elif api_format==4:
+            if api_format==4:
                 # translate openai chat completion messages format into one big string.
                 messages_array = genparams.get('messages', [])
                 adapter_obj = genparams.get('adapter', {})
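
Note on the hunk above: formats 3 and 4 now share one parameter-translation block, and three new OpenAI fields (seed, ignore_eos, mirostat_mode) are accepted. A minimal standalone sketch of that mapping for illustration; the function name is hypothetical and the defaults mirror the diff:

    # Sketch of the shared OpenAI parameter mapping (hypothetical helper name).
    def map_oai_params(genparams: dict) -> dict:
        frqp = genparams.get('frequency_penalty', 0.1)
        genparams["rep_pen"] = genparams.get('presence_penalty', frqp) + 1  # shift penalty into rep_pen space
        genparams["max_length"] = genparams.get('max_tokens', 100)
        # openai allows either a string or a list as a stop sequence
        stop = genparams.get('stop', [])
        genparams["stop_sequence"] = stop if isinstance(stop, list) else [stop]
        # the three newly accepted fields:
        genparams["sampler_seed"] = genparams.get('seed', -1)
        genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
        genparams["mirostat"] = genparams.get('mirostat_mode', 0)
        return genparams

    # e.g. map_oai_params({"max_tokens": 64, "stop": "###", "seed": 42})
    # yields rep_pen=1.1, max_length=64, stop_sequence=["###"], sampler_seed=42
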
@@ -470,23 +476,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 messages_string += assistant_message_end
                 messages_string += assistant_message_start
                 genparams["prompt"] = messages_string
-                frqp = genparams.get('frequency_penalty', 0.1)
-                scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
-                genparams["max_length"] = genparams.get('max_tokens', 80)
-                genparams["rep_pen"] = scaled_rep_pen
-                # openai allows either a string or a list as a stop sequence
-                if isinstance(genparams.get('stop',[]), list):
-                    genparams["stop_sequence"] = genparams.get('stop', [])
-                else:
-                    genparams["stop_sequence"] = [genparams.get('stop')]

             return generate(
                 prompt=genparams.get('prompt', ""),
                 memory=genparams.get('memory', ""),
                 max_context_length=genparams.get('max_context_length', maxctx),
-                max_length=genparams.get('max_length', 80),
+                max_length=genparams.get('max_length', 100),
                 temperature=genparams.get('temperature', 0.7),
                 top_k=genparams.get('top_k', 100),
                 top_a=genparams.get('top_a', 0.0),
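
This hunk deletes the copy of the same OpenAI parameter mapping that previously lived only in the chat (api_format 4) branch; with the elif-to-if change in the first hunk, chat requests now run through the shared block first and only add the message flattening on top. A rough sketch of that flattening, where the *_start/*_end role tags are placeholders for the values the real code reads from adapter_obj:

    # Rough sketch of the messages -> prompt flattening; tags are placeholders
    # standing in for the adapter_obj values the real code uses.
    def flatten_messages(messages_array):
        user_start, user_end = "\n### Instruction:\n", ""
        assistant_start, assistant_end = "\n### Response:\n", ""
        messages_string = ""
        for message in messages_array:
            if message.get('role') == "user":
                messages_string += user_start + message.get('content', "") + user_end
            else:
                messages_string += assistant_start + message.get('content', "") + assistant_end
        messages_string += assistant_start  # cue the model to reply as the assistant
        return messages_string
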
@@ -578,6 +574,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if api_format == 4: # if oai chat, set format to expected openai streaming response
                     event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
                     await self.send_oai_sse_event(event_str)
+                elif api_format == 3: # non chat completions
+                    event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                    await self.send_oai_sse_event(event_str)
                 else:
                     event_str = json.dumps({"token": tokenStr})
                     await self.send_kai_sse_event(event_str)
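
For reference, the per-token event the new api_format 3 branch emits would look like this on the wire, assuming send_oai_sse_event frames the payload as a "data:" line (the model name here is illustrative):

    import json

    # what the new text_completion branch serializes per streamed token
    tokenStr = "Hello"
    event_str = json.dumps({"id": "koboldcpp", "object": "text_completion",
                            "created": 1, "model": "koboldcpp/my-model",
                            "choices": [{"index": 0, "finish_reason": "length",
                                         "text": tokenStr}]})
    print(f"data: {event_str}\n")  # assumed SSE framing
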
@@ -817,7 +816,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 bring_terminal_to_foreground()

             # Check if streaming chat completions, if so, set stream mode to true
-            if api_format == 4 and "stream" in genparams and genparams["stream"]:
+            if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
                 sse_stream_flag = True

             gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
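
With the last hunk, plain (non-chat) completions requests can also opt into streaming. A hedged client-side sketch, exercising the new fields end to end; the endpoint path and port are the usual KoboldCpp defaults and are assumptions, not something this diff shows:

    import json
    import urllib.request

    payload = {
        "prompt": "Once upon a time",
        "max_tokens": 32,       # -> max_length (default now 100)
        "seed": 1234,           # new: -> sampler_seed
        "ignore_eos": False,    # new: -> use_default_badwordsids
        "mirostat_mode": 0,     # new: -> mirostat
        "stop": "###",          # string or list both accepted
        "stream": True,         # now honored for api_format 3 too
    }
    req = urllib.request.Request(
        "http://localhost:5001/v1/completions",  # assumed route/port
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        for raw in resp:
            line = raw.decode("utf-8").strip()
            if not line.startswith("data: "):
                continue
            chunk = line[len("data: "):]
            if chunk == "[DONE]":  # terminator, if the server sends one
                break
            print(json.loads(chunk)["choices"][0]["text"], end="", flush=True)
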