added simulated OAI endpoint
This commit is contained in:
parent
7f112e2cd4
commit
8bf6f7f8b0
1 changed file with 53 additions and 61 deletions
78
koboldcpp.py
78
koboldcpp.py
|
@ -381,36 +381,22 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
super().log_message(format, *args)
|
||||
pass
|
||||
|
||||
async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
|
||||
async def generate_text(self, genparams, api_format, stream_flag):
|
||||
|
||||
def run_blocking():
|
||||
if basic_api_flag:
|
||||
return generate(
|
||||
prompt=newprompt,
|
||||
max_length=genparams.get('max', 50),
|
||||
temperature=genparams.get('temperature', 0.8),
|
||||
top_k=int(genparams.get('top_k', 120)),
|
||||
top_a=genparams.get('top_a', 0.0),
|
||||
top_p=genparams.get('top_p', 0.85),
|
||||
typical_p=genparams.get('typical', 1.0),
|
||||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
mirostat=genparams.get('mirostat', 0),
|
||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
|
||||
stream_sse=stream_flag,
|
||||
grammar=genparams.get('grammar', ''),
|
||||
genkey=genparams.get('genkey', ''))
|
||||
if api_format==1:
|
||||
genparams["prompt"] = genparams.get('text', "")
|
||||
genparams["top_k"] = int(genparams.get('top_k', 120))
|
||||
genparams["max_length"]=genparams.get('max', 50)
|
||||
elif api_format==3:
|
||||
scaled_rep_pen = genparams.get('presence_penalty', 0.1) + 1
|
||||
genparams["max_length"] = genparams.get('max_tokens', 50)
|
||||
genparams["rep_pen"] = scaled_rep_pen
|
||||
|
||||
else:
|
||||
return generate(prompt=newprompt,
|
||||
return generate(
|
||||
prompt=genparams.get('prompt', ""),
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
max_length=genparams.get('max_length', 50),
|
||||
max_length=genparams.get('max_length', 80),
|
||||
temperature=genparams.get('temperature', 0.8),
|
||||
top_k=genparams.get('top_k', 120),
|
||||
top_a=genparams.get('top_a', 0.0),
|
||||
|
@ -418,7 +404,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
typical_p=genparams.get('typical', 1.0),
|
||||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 256),
|
||||
mirostat=genparams.get('mirostat', 0),
|
||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||
|
@ -441,7 +427,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
if args.debugmode!=-1:
|
||||
utfprint("\nOutput: " + recvtxt)
|
||||
|
||||
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
|
||||
if api_format==1:
|
||||
res = {"data": {"seqs":[recvtxt]}}
|
||||
elif api_format==3:
|
||||
res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
|
||||
"choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
|
||||
else:
|
||||
res = {"results": [{"text": recvtxt}]}
|
||||
|
||||
try:
|
||||
return res
|
||||
|
@ -490,13 +482,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
self.close_connection = True
|
||||
|
||||
|
||||
async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
|
||||
async def handle_request(self, genparams, api_format, stream_flag):
|
||||
tasks = []
|
||||
|
||||
if stream_flag:
|
||||
tasks.append(self.handle_sse_stream())
|
||||
|
||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
|
||||
generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
|
||||
tasks.append(generate_task)
|
||||
|
||||
try:
|
||||
|
@ -568,6 +560,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
||||
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
||||
|
||||
elif self.path.endswith('/api/extra/oai/v1/models'):
|
||||
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
||||
|
||||
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
||||
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
|
||||
|
@ -652,20 +646,24 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0
|
||||
|
||||
try:
|
||||
basic_api_flag = False
|
||||
kai_api_flag = False
|
||||
kai_sse_stream_flag = False
|
||||
|
||||
api_format = 0 #1=basic,2=kai,3=oai
|
||||
|
||||
if self.path.endswith('/request'):
|
||||
basic_api_flag = True
|
||||
api_format = 1
|
||||
|
||||
if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
||||
kai_api_flag = True
|
||||
api_format = 2
|
||||
|
||||
if self.path.endswith('/api/extra/generate/stream'):
|
||||
kai_api_flag = True
|
||||
api_format = 2
|
||||
kai_sse_stream_flag = True
|
||||
|
||||
if basic_api_flag or kai_api_flag:
|
||||
if self.path.endswith('/api/extra/oai/v1/completions'):
|
||||
api_format = 3
|
||||
|
||||
if api_format>0:
|
||||
genparams = None
|
||||
try:
|
||||
genparams = json.loads(body)
|
||||
|
@ -676,13 +674,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
if args.debugmode!=-1:
|
||||
utfprint("\nInput: " + json.dumps(genparams))
|
||||
|
||||
if kai_api_flag:
|
||||
fullprompt = genparams.get('prompt', "")
|
||||
else:
|
||||
fullprompt = genparams.get('text', "")
|
||||
newprompt = fullprompt
|
||||
|
||||
gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
|
||||
gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
|
||||
|
||||
try:
|
||||
# Headers are already sent when streaming
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue