improve cors and header handling
This commit is contained in:
parent
f604cffdce
commit
c6fe820357
1 changed files with 28 additions and 37 deletions
65
koboldcpp.py
65
koboldcpp.py
|
@ -520,9 +520,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
async def handle_sse_stream(self, api_format):
|
async def handle_sse_stream(self, api_format):
|
||||||
global friendlymodelname
|
global friendlymodelname
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Cache-Control", "no-cache")
|
self.send_header("cache-control", "no-cache")
|
||||||
self.send_header("Connection", "keep-alive")
|
self.send_header("connection", "keep-alive")
|
||||||
self.end_headers(force_json=True, sse_stream_flag=True)
|
self.end_headers(content_type='text/event-stream')
|
||||||
|
|
||||||
current_token = 0
|
current_token = 0
|
||||||
incomplete_token_buffer = bytearray()
|
incomplete_token_buffer = bytearray()
|
||||||
|
@ -589,10 +589,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
|
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
|
||||||
self.path = self.path.rstrip('/')
|
self.path = self.path.rstrip('/')
|
||||||
response_body = None
|
response_body = None
|
||||||
force_json = False
|
content_type = 'application/json'
|
||||||
|
|
||||||
if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
|
if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
|
||||||
|
content_type = 'text/html'
|
||||||
if self.embedded_kailite is None:
|
if self.embedded_kailite is None:
|
||||||
response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
|
response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
|
||||||
else:
|
else:
|
||||||
|
@ -638,9 +638,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
elif self.path.endswith('/v1/models'):
|
elif self.path.endswith('/v1/models'):
|
||||||
response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
|
||||||
force_json = True
|
|
||||||
|
|
||||||
elif self.path=="/api":
|
elif self.path=="/api":
|
||||||
|
content_type = 'text/html'
|
||||||
if self.embedded_kcpp_docs is None:
|
if self.embedded_kcpp_docs is None:
|
||||||
response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode()
|
response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode()
|
||||||
else:
|
else:
|
||||||
|
@ -648,41 +648,40 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
|
||||||
self.path = "/api"
|
self.path = "/api"
|
||||||
self.send_response(302)
|
self.send_response(302)
|
||||||
self.send_header("Location", self.path)
|
self.send_header("location", self.path)
|
||||||
self.end_headers()
|
self.end_headers(content_type='text/html')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if response_body is None:
|
if response_body is None:
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.end_headers()
|
self.end_headers(content_type='text/html')
|
||||||
rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
|
rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
|
||||||
self.wfile.write(rp.encode())
|
self.wfile.write(rp.encode())
|
||||||
else:
|
else:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Length', str(len(response_body)))
|
self.send_header('content-length', str(len(response_body)))
|
||||||
self.end_headers(force_json=force_json)
|
self.end_headers(content_type=content_type)
|
||||||
self.wfile.write(response_body)
|
self.wfile.write(response_body)
|
||||||
return
|
return
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
global modelbusy, requestsinqueue, currentusergenkey, totalgens
|
global modelbusy, requestsinqueue, currentusergenkey, totalgens
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['content-length'])
|
||||||
body = self.rfile.read(content_length)
|
body = self.rfile.read(content_length)
|
||||||
self.path = self.path.rstrip('/')
|
self.path = self.path.rstrip('/')
|
||||||
force_json = False
|
|
||||||
if self.path.endswith(('/api/extra/tokencount')):
|
if self.path.endswith(('/api/extra/tokencount')):
|
||||||
try:
|
try:
|
||||||
genparams = json.loads(body)
|
genparams = json.loads(body)
|
||||||
countprompt = genparams.get('prompt', "")
|
countprompt = genparams.get('prompt', "")
|
||||||
count = handle.token_count(countprompt.encode("UTF-8"))
|
count = handle.token_count(countprompt.encode("UTF-8"))
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps({"value": count}).encode())
|
self.wfile.write(json.dumps({"value": count}).encode())
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
utfprint("Count Tokens - Body Error: " + str(e))
|
utfprint("Count Tokens - Body Error: " + str(e))
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.end_headers()
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps({"value": -1}).encode())
|
self.wfile.write(json.dumps({"value": -1}).encode())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -699,7 +698,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
ag = handle.abort_generate()
|
ag = handle.abort_generate()
|
||||||
time.sleep(0.3) #short delay before replying
|
time.sleep(0.3) #short delay before replying
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
|
self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
|
||||||
print("\nGeneration Aborted")
|
print("\nGeneration Aborted")
|
||||||
else:
|
else:
|
||||||
|
@ -721,7 +720,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
pendtxt = handle.get_pending_output()
|
pendtxt = handle.get_pending_output()
|
||||||
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -731,7 +730,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
requestsinqueue += 1
|
requestsinqueue += 1
|
||||||
if not modelbusy.acquire(blocking=reqblocking):
|
if not modelbusy.acquire(blocking=reqblocking):
|
||||||
self.send_response(503)
|
self.send_response(503)
|
||||||
self.end_headers()
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps({"detail": {
|
self.wfile.write(json.dumps({"detail": {
|
||||||
"msg": "Server is busy; please try again later.",
|
"msg": "Server is busy; please try again later.",
|
||||||
"type": "service_unavailable",
|
"type": "service_unavailable",
|
||||||
|
@ -753,15 +752,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
|
|
||||||
if self.path.endswith('/api/extra/generate/stream'):
|
if self.path.endswith('/api/extra/generate/stream'):
|
||||||
api_format = 2
|
api_format = 2
|
||||||
sse_stream_flag = True
|
|
||||||
|
|
||||||
if self.path.endswith('/v1/completions'):
|
if self.path.endswith('/v1/completions'):
|
||||||
api_format = 3
|
api_format = 3
|
||||||
force_json = True
|
|
||||||
|
|
||||||
if self.path.endswith('/v1/chat/completions'):
|
if self.path.endswith('/v1/chat/completions'):
|
||||||
api_format = 4
|
api_format = 4
|
||||||
force_json = True
|
|
||||||
|
|
||||||
if api_format > 0:
|
if api_format > 0:
|
||||||
genparams = None
|
genparams = None
|
||||||
|
@ -787,7 +783,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
# Headers are already sent when streaming
|
# Headers are already sent when streaming
|
||||||
if not sse_stream_flag:
|
if not sse_stream_flag:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers(force_json=force_json)
|
self.end_headers(content_type='application/json')
|
||||||
self.wfile.write(json.dumps(gen).encode())
|
self.wfile.write(json.dumps(gen).encode())
|
||||||
except:
|
except:
|
||||||
print("Generate: The response could not be sent, maybe connection was terminated?")
|
print("Generate: The response could not be sent, maybe connection was terminated?")
|
||||||
|
@ -796,28 +792,23 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
modelbusy.release()
|
modelbusy.release()
|
||||||
|
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.end_headers()
|
self.end_headers(content_type='text/html')
|
||||||
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(content_type='text/html')
|
||||||
|
|
||||||
def do_HEAD(self):
|
def do_HEAD(self):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers(content_type='text/html')
|
||||||
|
|
||||||
def end_headers(self, force_json=False, sse_stream_flag=False):
|
def end_headers(self, content_type=None):
|
||||||
self.send_header('Access-Control-Allow-Origin', '*')
|
self.send_header('access-control-allow-origin', '*')
|
||||||
self.send_header('Access-Control-Allow-Methods', '*')
|
self.send_header('access-control-allow-methods', '*')
|
||||||
self.send_header('Access-Control-Allow-Headers', '*')
|
self.send_header('access-control-allow-headers', '*, Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Client-Agent, X-Fields, Content-Type, Authorization, X-Requested-With, X-HTTP-Method-Override, apikey, genkey')
|
||||||
if ("/api" in self.path and self.path!="/api") or force_json:
|
if content_type is not None:
|
||||||
if sse_stream_flag:
|
self.send_header('content-type', content_type)
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
|
||||||
else:
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
else:
|
|
||||||
self.send_header('Content-Type', 'text/html')
|
|
||||||
return super(ServerRequestHandler, self).end_headers()
|
return super(ServerRequestHandler, self).end_headers()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1506,7 +1497,7 @@ def run_horde_worker(args, api_key, worker_name):
|
||||||
if method=='POST':
|
if method=='POST':
|
||||||
json_payload = json.dumps(data).encode('utf-8')
|
json_payload = json.dumps(data).encode('utf-8')
|
||||||
request = urllib.request.Request(url, data=json_payload, headers=headers, method=method)
|
request = urllib.request.Request(url, data=json_payload, headers=headers, method=method)
|
||||||
request.add_header('Content-Type', 'application/json')
|
request.add_header('content-type', 'application/json')
|
||||||
else:
|
else:
|
||||||
request = urllib.request.Request(url, headers=headers, method=method)
|
request = urllib.request.Request(url, headers=headers, method=method)
|
||||||
response_data = ""
|
response_data = ""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue