added simulated OAI endpoint

Concedo 2023-09-27 00:49:24 +08:00
parent 7f112e2cd4
commit 8bf6f7f8b0


@@ -381,54 +381,40 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             super().log_message(format, *args)
         pass
 
-    async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
+    async def generate_text(self, genparams, api_format, stream_flag):
         def run_blocking():
-            if basic_api_flag:
-                return generate(
-                    prompt=newprompt,
-                    max_length=genparams.get('max', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    mirostat=genparams.get('mirostat', 0),
-                    mirostat_tau=genparams.get('mirostat_tau', 5.0),
-                    mirostat_eta=genparams.get('mirostat_eta', 0.1),
-                    sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', []),
-                    use_default_badwordsids=genparams.get('use_default_badwordsids', True),
-                    stream_sse=stream_flag,
-                    grammar=genparams.get('grammar', ''),
-                    genkey=genparams.get('genkey', ''))
+            if api_format==1:
+                genparams["prompt"] = genparams.get('text', "")
+                genparams["top_k"] = int(genparams.get('top_k', 120))
+                genparams["max_length"]=genparams.get('max', 50)
+            elif api_format==3:
+                scaled_rep_pen = genparams.get('presence_penalty', 0.1) + 1
+                genparams["max_length"] = genparams.get('max_tokens', 50)
+                genparams["rep_pen"] = scaled_rep_pen
-            else:
-                return generate(prompt=newprompt,
-                    max_context_length=genparams.get('max_context_length', maxctx),
-                    max_length=genparams.get('max_length', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=genparams.get('top_k', 120),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    mirostat=genparams.get('mirostat', 0),
-                    mirostat_tau=genparams.get('mirostat_tau', 5.0),
-                    mirostat_eta=genparams.get('mirostat_eta', 0.1),
-                    sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', []),
-                    use_default_badwordsids=genparams.get('use_default_badwordsids', True),
-                    stream_sse=stream_flag,
-                    grammar=genparams.get('grammar', ''),
-                    genkey=genparams.get('genkey', ''))
+            return generate(
+                prompt=genparams.get('prompt', ""),
+                max_context_length=genparams.get('max_context_length', maxctx),
+                max_length=genparams.get('max_length', 80),
+                temperature=genparams.get('temperature', 0.8),
+                top_k=genparams.get('top_k', 120),
+                top_a=genparams.get('top_a', 0.0),
+                top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
+                rep_pen=genparams.get('rep_pen', 1.1),
+                rep_pen_range=genparams.get('rep_pen_range', 256),
+                mirostat=genparams.get('mirostat', 0),
+                mirostat_tau=genparams.get('mirostat_tau', 5.0),
+                mirostat_eta=genparams.get('mirostat_eta', 0.1),
+                sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
+                seed=genparams.get('sampler_seed', -1),
+                stop_sequence=genparams.get('stop_sequence', []),
+                use_default_badwordsids=genparams.get('use_default_badwordsids', True),
+                stream_sse=stream_flag,
+                grammar=genparams.get('grammar', ''),
+                genkey=genparams.get('genkey', ''))
 
         recvtxt = ""
         if stream_flag:
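The refactor above collapses the two parallel `generate()` calls into one: each API format first remaps its own field names onto the KoboldAI-native parameter names, then everything falls through to the shared call. A minimal standalone restatement of that remapping, using the hypothetical helper name `normalize_genparams` (not a function in koboldcpp itself):

```python
# Sketch of the remapping in the diff above; "normalize_genparams" is a
# hypothetical name chosen for illustration.
def normalize_genparams(genparams, api_format):
    if api_format == 1:    # basic API sends 'text' and 'max'
        genparams["prompt"] = genparams.get('text', "")
        genparams["top_k"] = int(genparams.get('top_k', 120))
        genparams["max_length"] = genparams.get('max', 50)
    elif api_format == 3:  # OAI-style sends 'max_tokens' and 'presence_penalty'
        # OAI's additive presence_penalty is folded into KoboldAI's
        # multiplicative rep_pen, whose neutral value is 1.0.
        genparams["rep_pen"] = genparams.get('presence_penalty', 0.1) + 1
        genparams["max_length"] = genparams.get('max_tokens', 50)
    return genparams       # api_format 2 (KAI) already uses native names

assert normalize_genparams({"presence_penalty": 0.5}, 3)["rep_pen"] == 1.5
```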
@@ -441,7 +427,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         if args.debugmode!=-1:
             utfprint("\nOutput: " + recvtxt)
 
-        res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
+        if api_format==1:
+            res = {"data": {"seqs":[recvtxt]}}
+        elif api_format==3:
+            res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
+            "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+        else:
+            res = {"results": [{"text": recvtxt}]}
 
         try:
             return res
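With the OAI-shaped `text_completion` response in place, any client that speaks that schema can consume it. A smoke test against a locally running instance (the running server and port 5001 are assumptions about your launch settings):

```python
import requests

# Assumes koboldcpp is already serving locally on its usual default port.
resp = requests.post(
    "http://localhost:5001/api/extra/oai/v1/completions",
    json={"prompt": "Once upon a time", "max_tokens": 32, "temperature": 0.8},
)
data = resp.json()
print(data["choices"][0]["text"])      # the generated continuation
print(data["object"], data["model"])   # 'text_completion' 'koboldcpp'
```

Note that the simulated response hardcodes `finish_reason` as `"length"` and uses a static `id` and `created` timestamp, so clients should not rely on those fields.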
@@ -490,13 +482,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.close_connection = True
 
-    async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
+    async def handle_request(self, genparams, api_format, stream_flag):
         tasks = []
 
         if stream_flag:
             tasks.append(self.handle_sse_stream())
 
-        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
+        generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
         tasks.append(generate_task)
 
         try:
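This hunk is pure plumbing (`api_format` replaces the `newprompt`/`basic_api_flag` pair), but the surrounding pattern is worth seeing in isolation: when streaming, the SSE writer and the generation run as sibling tasks on one event loop so tokens can be flushed while generation is in flight. A generic sketch, with sleeps standing in for `handle_sse_stream` and `generate_text`:

```python
import asyncio

async def handle(stream_flag):
    tasks = []
    if stream_flag:
        # stands in for handle_sse_stream(): writes chunks while generation runs
        tasks.append(asyncio.create_task(asyncio.sleep(0.1)))
    # stands in for generate_text(...): produces the final text
    generate_task = asyncio.create_task(asyncio.sleep(0.1, result="final text"))
    tasks.append(generate_task)
    await asyncio.gather(*tasks)
    return generate_task.result()

print(asyncio.run(handle(True)))   # -> final text
```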
@@ -568,6 +560,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
                     response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
 
+        elif self.path.endswith('/api/extra/oai/v1/models'):
+            response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
 
         elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
             response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
@@ -652,20 +646,24 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0
 
         try:
-            basic_api_flag = False
-            kai_api_flag = False
             kai_sse_stream_flag = False
+            api_format = 0 #1=basic,2=kai,3=oai
 
             if self.path.endswith('/request'):
-                basic_api_flag = True
+                api_format = 1
 
             if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
-                kai_api_flag = True
+                api_format = 2
 
             if self.path.endswith('/api/extra/generate/stream'):
-                kai_api_flag = True
+                api_format = 2
                 kai_sse_stream_flag = True
 
-            if basic_api_flag or kai_api_flag:
+            if self.path.endswith('/api/extra/oai/v1/completions'):
+                api_format = 3
+
+            if api_format>0:
                 genparams = None
                 try:
                     genparams = json.loads(body)
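Collapsing the two booleans into a single `api_format` integer keeps the endpoint dispatch in one place and makes adding the OAI route a one-liner. The same routing, restated as a hypothetical standalone function (the real code also sets `kai_sse_stream_flag` for the stream endpoint and checks each suffix independently rather than via elif):

```python
def detect_api_format(path):
    """Hypothetical restatement of the suffix routing above."""
    if path.endswith('/request'):
        return 1  # basic
    if path.endswith(('/api/v1/generate', '/api/latest/generate',
                      '/api/extra/generate/stream')):
        return 2  # kai
    if path.endswith('/api/extra/oai/v1/completions'):
        return 3  # oai
    return 0      # not a generation endpoint

assert detect_api_format('/api/extra/oai/v1/completions') == 3
```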
@@ -676,13 +674,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if args.debugmode!=-1:
                     utfprint("\nInput: " + json.dumps(genparams))
 
-                if kai_api_flag:
-                    fullprompt = genparams.get('prompt', "")
-                else:
-                    fullprompt = genparams.get('text', "")
-                newprompt = fullprompt
-
-                gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+                gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
 
                 try:
                     # Headers are already sent when streaming
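Taken together, the simulated completions and models routes are enough for clients built on the OpenAI SDK to point at koboldcpp. A sketch using the pre-1.0 `openai` Python package, current when this commit landed (the base URL and port are assumptions about your setup, and the key is a placeholder since the diff adds no auth check):

```python
import openai  # pre-1.0 SDK, e.g. openai==0.28

openai.api_base = "http://localhost:5001/api/extra/oai/v1"
openai.api_key = "sk-local"   # placeholder; the simulated endpoint ignores it

resp = openai.Completion.create(model="koboldcpp",
                                prompt="Hello,",
                                max_tokens=16)
print(resp["choices"][0]["text"])
```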