From c6fe820357fbebf98ccf6146d9363ac67d77771d Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Thu, 12 Oct 2023 14:53:39 +0800 Subject: [PATCH] improve cors and header handling --- koboldcpp.py | 65 ++++++++++++++++++++++------------------------------ 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index dd1873f3a..e573a9433 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -520,9 +520,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): async def handle_sse_stream(self, api_format): global friendlymodelname self.send_response(200) - self.send_header("Cache-Control", "no-cache") - self.send_header("Connection", "keep-alive") - self.end_headers(force_json=True, sse_stream_flag=True) + self.send_header("cache-control", "no-cache") + self.send_header("connection", "keep-alive") + self.end_headers(content_type='text/event-stream') current_token = 0 incomplete_token_buffer = bytearray() @@ -589,10 +589,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens self.path = self.path.rstrip('/') response_body = None - force_json = False + content_type = 'application/json' if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without / - + content_type = 'text/html' if self.embedded_kailite is None: response_body = (f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.").encode() else: @@ -638,9 +638,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): elif self.path.endswith('/v1/models'): response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode()) - force_json = True elif self.path=="/api": + content_type = 'text/html' if self.embedded_kcpp_docs is None: response_body = (f"KoboldCpp partial API reference can be found at the wiki: https://github.com/LostRuins/koboldcpp/wiki").encode() else: @@ -648,41 +648,40 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')): self.path = "/api" self.send_response(302) - self.send_header("Location", self.path) - self.end_headers() + self.send_header("location", self.path) + self.end_headers(content_type='text/html') return None if response_body is None: self.send_response(404) - self.end_headers() + self.end_headers(content_type='text/html') rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.' self.wfile.write(rp.encode()) else: self.send_response(200) - self.send_header('Content-Length', str(len(response_body))) - self.end_headers(force_json=force_json) + self.send_header('content-length', str(len(response_body))) + self.end_headers(content_type=content_type) self.wfile.write(response_body) return def do_POST(self): global modelbusy, requestsinqueue, currentusergenkey, totalgens - content_length = int(self.headers['Content-Length']) + content_length = int(self.headers['content-length']) body = self.rfile.read(content_length) self.path = self.path.rstrip('/') - force_json = False if self.path.endswith(('/api/extra/tokencount')): try: genparams = json.loads(body) countprompt = genparams.get('prompt', "") count = handle.token_count(countprompt.encode("UTF-8")) self.send_response(200) - self.end_headers() + self.end_headers(content_type='application/json') self.wfile.write(json.dumps({"value": count}).encode()) except ValueError as e: utfprint("Count Tokens - Body Error: " + str(e)) self.send_response(400) - self.end_headers() + self.end_headers(content_type='application/json') self.wfile.write(json.dumps({"value": -1}).encode()) return @@ -699,7 +698,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): ag = handle.abort_generate() time.sleep(0.3) #short delay before replying self.send_response(200) - self.end_headers() + self.end_headers(content_type='application/json') self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode()) print("\nGeneration Aborted") else: @@ -721,7 +720,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): pendtxt = handle.get_pending_output() pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore") self.send_response(200) - self.end_headers() + self.end_headers(content_type='application/json') self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode()) return @@ -731,7 +730,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): requestsinqueue += 1 if not modelbusy.acquire(blocking=reqblocking): self.send_response(503) - self.end_headers() + self.end_headers(content_type='application/json') self.wfile.write(json.dumps({"detail": { "msg": "Server is busy; please try again later.", "type": "service_unavailable", @@ -753,15 +752,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): if self.path.endswith('/api/extra/generate/stream'): api_format = 2 - sse_stream_flag = True if self.path.endswith('/v1/completions'): api_format = 3 - force_json = True if self.path.endswith('/v1/chat/completions'): api_format = 4 - force_json = True if api_format > 0: genparams = None @@ -787,7 +783,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): # Headers are already sent when streaming if not sse_stream_flag: self.send_response(200) - self.end_headers(force_json=force_json) + self.end_headers(content_type='application/json') self.wfile.write(json.dumps(gen).encode()) except: print("Generate: The response could not be sent, maybe connection was terminated?") @@ -796,28 +792,23 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): modelbusy.release() self.send_response(404) - self.end_headers() + self.end_headers(content_type='text/html') def do_OPTIONS(self): self.send_response(200) - self.end_headers() + self.end_headers(content_type='text/html') def do_HEAD(self): self.send_response(200) - self.end_headers() + self.end_headers(content_type='text/html') - def end_headers(self, force_json=False, sse_stream_flag=False): - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', '*') - self.send_header('Access-Control-Allow-Headers', '*') - if ("/api" in self.path and self.path!="/api") or force_json: - if sse_stream_flag: - self.send_header('Content-Type', 'text/event-stream') - else: - self.send_header('Content-Type', 'application/json') - else: - self.send_header('Content-Type', 'text/html') + def end_headers(self, content_type=None): + self.send_header('access-control-allow-origin', '*') + self.send_header('access-control-allow-methods', '*') + self.send_header('access-control-allow-headers', '*, Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Client-Agent, X-Fields, Content-Type, Authorization, X-Requested-With, X-HTTP-Method-Override, apikey, genkey') + if content_type is not None: + self.send_header('content-type', content_type) return super(ServerRequestHandler, self).end_headers() @@ -1506,7 +1497,7 @@ def run_horde_worker(args, api_key, worker_name): if method=='POST': json_payload = json.dumps(data).encode('utf-8') request = urllib.request.Request(url, data=json_payload, headers=headers, method=method) - request.add_header('Content-Type', 'application/json') + request.add_header('content-type', 'application/json') else: request = urllib.request.Request(url, headers=headers, method=method) response_data = ""