Force OpenAI-compatible (OAI) endpoints to return JSON
This commit is contained in:
parent
0c47e79537
commit
23b9d3af49
1 changed file with 8 additions and 4 deletions
12
koboldcpp.py
12
koboldcpp.py
|
@ -526,6 +526,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
|
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
|
||||||
self.path = self.path.rstrip('/')
|
self.path = self.path.rstrip('/')
|
||||||
response_body = None
|
response_body = None
|
||||||
|
force_json = False
|
||||||
|
|
||||||
if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
|
if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
|
||||||
if args.stream and not "streaming=1" in self.path:
|
if args.stream and not "streaming=1" in self.path:
|
||||||
|
@ -585,6 +586,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
|
elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
|
||||||
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
||||||
|
force_json = True
|
||||||
|
|
||||||
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
||||||
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
|
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
|
||||||
|
@ -598,7 +600,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
else:
|
else:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Length', str(len(response_body)))
|
self.send_header('Content-Length', str(len(response_body)))
|
||||||
self.end_headers()
|
self.end_headers(force_json=force_json)
|
||||||
self.wfile.write(response_body)
|
self.wfile.write(response_body)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -607,6 +609,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
body = self.rfile.read(content_length)
|
body = self.rfile.read(content_length)
|
||||||
self.path = self.path.rstrip('/')
|
self.path = self.path.rstrip('/')
|
||||||
|
force_json = False
|
||||||
|
|
||||||
if self.path.endswith(('/api/extra/tokencount')):
|
if self.path.endswith(('/api/extra/tokencount')):
|
||||||
try:
|
try:
|
||||||
|
@ -686,6 +689,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
|
if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
|
||||||
api_format = 3
|
api_format = 3
|
||||||
|
force_json = True
|
||||||
|
|
||||||
if api_format>0:
|
if api_format>0:
|
||||||
genparams = None
|
genparams = None
|
||||||
|
@ -707,7 +711,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
# Headers are already sent when streaming
|
# Headers are already sent when streaming
|
||||||
if not kai_sse_stream_flag:
|
if not kai_sse_stream_flag:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(force_json=force_json)
|
||||||
self.wfile.write(json.dumps(gen).encode())
|
self.wfile.write(json.dumps(gen).encode())
|
||||||
except:
|
except:
|
||||||
print("Generate: The response could not be sent, maybe connection was terminated?")
|
print("Generate: The response could not be sent, maybe connection was terminated?")
|
||||||
|
@ -728,11 +732,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
def end_headers(self):
|
def end_headers(self, force_json=False):
|
||||||
self.send_header('Access-Control-Allow-Origin', '*')
|
self.send_header('Access-Control-Allow-Origin', '*')
|
||||||
self.send_header('Access-Control-Allow-Methods', '*')
|
self.send_header('Access-Control-Allow-Methods', '*')
|
||||||
self.send_header('Access-Control-Allow-Headers', '*')
|
self.send_header('Access-Control-Allow-Headers', '*')
|
||||||
if "/api" in self.path:
|
if "/api" in self.path or force_json:
|
||||||
if self.path.endswith("/stream"):
|
if self.path.endswith("/stream"):
|
||||||
self.send_header('Content-type', 'text/event-stream')
|
self.send_header('Content-type', 'text/event-stream')
|
||||||
self.send_header('Content-type', 'application/json')
|
self.send_header('Content-type', 'application/json')
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue