force oai endpoints to return json

This commit is contained in:
Concedo 2023-10-02 12:45:14 +08:00
parent 0c47e79537
commit 23b9d3af49

View file

@ -526,6 +526,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
self.path = self.path.rstrip('/')
response_body = None
force_json = False
if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
if args.stream and not "streaming=1" in self.path:
@ -585,6 +586,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
force_json = True
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
@ -598,7 +600,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
else:
self.send_response(200)
self.send_header('Content-Length', str(len(response_body)))
self.end_headers()
self.end_headers(force_json=force_json)
self.wfile.write(response_body)
return
@ -607,6 +609,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
content_length = int(self.headers['Content-Length'])
body = self.rfile.read(content_length)
self.path = self.path.rstrip('/')
force_json = False
if self.path.endswith(('/api/extra/tokencount')):
try:
@ -686,6 +689,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
api_format = 3
force_json = True
if api_format>0:
genparams = None
@ -707,7 +711,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
# Headers are already sent when streaming
if not kai_sse_stream_flag:
self.send_response(200)
self.end_headers()
self.end_headers(force_json=force_json)
self.wfile.write(json.dumps(gen).encode())
except:
print("Generate: The response could not be sent, maybe connection was terminated?")
@ -728,11 +732,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
self.send_response(200)
self.end_headers()
def end_headers(self):
def end_headers(self, force_json=False):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
self.send_header('Access-Control-Allow-Headers', '*')
if "/api" in self.path:
if "/api" in self.path or force_json:
if self.path.endswith("/stream"):
self.send_header('Content-type', 'text/event-stream')
self.send_header('Content-type', 'application/json')