added simulated OAI endpoint
This commit is contained in:
parent
7f112e2cd4
commit
8bf6f7f8b0
1 changed files with 53 additions and 61 deletions
78
koboldcpp.py
78
koboldcpp.py
|
@ -381,36 +381,22 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
super().log_message(format, *args)
|
super().log_message(format, *args)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
|
async def generate_text(self, genparams, api_format, stream_flag):
|
||||||
|
|
||||||
def run_blocking():
|
def run_blocking():
|
||||||
if basic_api_flag:
|
if api_format==1:
|
||||||
return generate(
|
genparams["prompt"] = genparams.get('text', "")
|
||||||
prompt=newprompt,
|
genparams["top_k"] = int(genparams.get('top_k', 120))
|
||||||
max_length=genparams.get('max', 50),
|
genparams["max_length"]=genparams.get('max', 50)
|
||||||
temperature=genparams.get('temperature', 0.8),
|
elif api_format==3:
|
||||||
top_k=int(genparams.get('top_k', 120)),
|
scaled_rep_pen = genparams.get('presence_penalty', 0.1) + 1
|
||||||
top_a=genparams.get('top_a', 0.0),
|
genparams["max_length"] = genparams.get('max_tokens', 50)
|
||||||
top_p=genparams.get('top_p', 0.85),
|
genparams["rep_pen"] = scaled_rep_pen
|
||||||
typical_p=genparams.get('typical', 1.0),
|
|
||||||
tfs=genparams.get('tfs', 1.0),
|
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
|
||||||
mirostat=genparams.get('mirostat', 0),
|
|
||||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
|
||||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
|
||||||
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
|
|
||||||
seed=genparams.get('sampler_seed', -1),
|
|
||||||
stop_sequence=genparams.get('stop_sequence', []),
|
|
||||||
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
|
|
||||||
stream_sse=stream_flag,
|
|
||||||
grammar=genparams.get('grammar', ''),
|
|
||||||
genkey=genparams.get('genkey', ''))
|
|
||||||
|
|
||||||
else:
|
return generate(
|
||||||
return generate(prompt=newprompt,
|
prompt=genparams.get('prompt', ""),
|
||||||
max_context_length=genparams.get('max_context_length', maxctx),
|
max_context_length=genparams.get('max_context_length', maxctx),
|
||||||
max_length=genparams.get('max_length', 50),
|
max_length=genparams.get('max_length', 80),
|
||||||
temperature=genparams.get('temperature', 0.8),
|
temperature=genparams.get('temperature', 0.8),
|
||||||
top_k=genparams.get('top_k', 120),
|
top_k=genparams.get('top_k', 120),
|
||||||
top_a=genparams.get('top_a', 0.0),
|
top_a=genparams.get('top_a', 0.0),
|
||||||
|
@ -418,7 +404,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
typical_p=genparams.get('typical', 1.0),
|
typical_p=genparams.get('typical', 1.0),
|
||||||
tfs=genparams.get('tfs', 1.0),
|
tfs=genparams.get('tfs', 1.0),
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
rep_pen=genparams.get('rep_pen', 1.1),
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
rep_pen_range=genparams.get('rep_pen_range', 256),
|
||||||
mirostat=genparams.get('mirostat', 0),
|
mirostat=genparams.get('mirostat', 0),
|
||||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||||
|
@ -441,7 +427,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
if args.debugmode!=-1:
|
if args.debugmode!=-1:
|
||||||
utfprint("\nOutput: " + recvtxt)
|
utfprint("\nOutput: " + recvtxt)
|
||||||
|
|
||||||
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
|
if api_format==1:
|
||||||
|
res = {"data": {"seqs":[recvtxt]}}
|
||||||
|
elif api_format==3:
|
||||||
|
res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp",
|
||||||
|
"choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
|
||||||
|
else:
|
||||||
|
res = {"results": [{"text": recvtxt}]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return res
|
return res
|
||||||
|
@ -490,13 +482,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
self.close_connection = True
|
self.close_connection = True
|
||||||
|
|
||||||
|
|
||||||
async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
|
async def handle_request(self, genparams, api_format, stream_flag):
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
if stream_flag:
|
if stream_flag:
|
||||||
tasks.append(self.handle_sse_stream())
|
tasks.append(self.handle_sse_stream())
|
||||||
|
|
||||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
|
generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
|
||||||
tasks.append(generate_task)
|
tasks.append(generate_task)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -568,6 +560,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
||||||
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
||||||
|
|
||||||
|
elif self.path.endswith('/api/extra/oai/v1/models'):
|
||||||
|
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
||||||
|
|
||||||
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
||||||
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
|
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
|
||||||
|
@ -652,20 +646,24 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0
|
requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
basic_api_flag = False
|
|
||||||
kai_api_flag = False
|
|
||||||
kai_sse_stream_flag = False
|
kai_sse_stream_flag = False
|
||||||
|
|
||||||
|
api_format = 0 #1=basic,2=kai,3=oai
|
||||||
|
|
||||||
if self.path.endswith('/request'):
|
if self.path.endswith('/request'):
|
||||||
basic_api_flag = True
|
api_format = 1
|
||||||
|
|
||||||
if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
||||||
kai_api_flag = True
|
api_format = 2
|
||||||
|
|
||||||
if self.path.endswith('/api/extra/generate/stream'):
|
if self.path.endswith('/api/extra/generate/stream'):
|
||||||
kai_api_flag = True
|
api_format = 2
|
||||||
kai_sse_stream_flag = True
|
kai_sse_stream_flag = True
|
||||||
|
|
||||||
if basic_api_flag or kai_api_flag:
|
if self.path.endswith('/api/extra/oai/v1/completions'):
|
||||||
|
api_format = 3
|
||||||
|
|
||||||
|
if api_format>0:
|
||||||
genparams = None
|
genparams = None
|
||||||
try:
|
try:
|
||||||
genparams = json.loads(body)
|
genparams = json.loads(body)
|
||||||
|
@ -676,13 +674,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
if args.debugmode!=-1:
|
if args.debugmode!=-1:
|
||||||
utfprint("\nInput: " + json.dumps(genparams))
|
utfprint("\nInput: " + json.dumps(genparams))
|
||||||
|
|
||||||
if kai_api_flag:
|
gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag))
|
||||||
fullprompt = genparams.get('prompt', "")
|
|
||||||
else:
|
|
||||||
fullprompt = genparams.get('text', "")
|
|
||||||
newprompt = fullprompt
|
|
||||||
|
|
||||||
gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Headers are already sent when streaming
|
# Headers are already sent when streaming
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue