From 13b4c05d6619c2c8110bd9bef4482ff4169a7e8f Mon Sep 17 00:00:00 2001 From: InconsolableCellist Date: Tue, 28 Mar 2023 16:59:27 -0600 Subject: [PATCH] Some more code cleanup --- llama_for_kobold.py | 46 +++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/llama_for_kobold.py b/llama_for_kobold.py index 4dd33d926..1c8b17a3e 100644 --- a/llama_for_kobold.py +++ b/llama_for_kobold.py @@ -87,39 +87,42 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): super().__init__(*args, **kwargs) def do_GET(self): - if self.path=="/" or self.path.startswith('/?') or self.path.startswith('?'): + global maxctx, maxlen, friendlymodelname + if self.path in ["/", "/?"] or self.path.startswith('/?'): if self.embedded_kailite is None: - self.send_response(200) - self.end_headers() - self.wfile.write(b'Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.') + response_body = ( + b"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or " + b"use this URL to connect.").format(self.port).encode() else: + response_body = self.embedded_kailite + self.send_response(200) + self.send_header('Content-Length', str(len(response_body))) self.end_headers() - self.wfile.write(self.embedded_kailite) + self.wfile.write(response_body) return - if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/') or self.path.endswith('/api/v1/model') or self.path.endswith('/api/latest/model'): + self.path = self.path.rstrip('/') + if self.path.endswith(('/api/v1/model', '/api/latest/model')): self.send_response(200) self.end_headers() - global friendlymodelname - self.wfile.write(json.dumps({"result": friendlymodelname }).encode()) + result = {'result': friendlymodelname } + self.wfile.write(json.dumps(result).encode()) return - if self.path.endswith('/api/v1/config/max_length/') or self.path.endswith('/api/latest/config/max_length/') or self.path.endswith('/api/v1/config/max_length') or self.path.endswith('/api/latest/config/max_length'): + if self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')): self.send_response(200) self.end_headers() - global maxlen - self.wfile.write(json.dumps({"value":maxlen}).encode()) + self.wfile.write(json.dumps({"value": maxlen}).encode()) return - if self.path.endswith('/api/v1/config/max_context_length/') or self.path.endswith('/api/latest/config/max_context_length/') or self.path.endswith('/api/v1/config/max_context_length') or self.path.endswith('/api/latest/config/max_context_length'): + if self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')): self.send_response(200) self.end_headers() - global maxctx - self.wfile.write(json.dumps({"value":maxctx}).encode()) + self.wfile.write(json.dumps({"value": maxctx}).encode()) return - if self.path.endswith('/api/v1/config/soft_prompt') or self.path.endswith('/api/v1/config/soft_prompt/') or self.path.endswith('/api/latest/config/soft_prompt') or self.path.endswith('/api/latest/config/soft_prompt/'): + if self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')): self.send_response(200) self.end_headers() self.wfile.write(json.dumps({"value":""}).encode()) @@ -130,12 +133,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.' self.wfile.write(rp.encode()) return - def do_POST(self): global modelbusy content_length = int(self.headers['Content-Length']) body = self.rfile.read(content_length) + basic_api_flag = False + kai_api_flag = False + self.path = self.path.rstrip('/') if modelbusy: self.send_response(503) @@ -146,11 +151,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): }}).encode()) return - basic_api_flag = False - kai_api_flag = False - if self.path.endswith('/request') or self.path.endswith('/request'): + if self.path.endswith('/request'): basic_api_flag = True - if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/') or self.path.endswith('/api/v1/generate') or self.path.endswith('/api/latest/generate'): + + if self.path.endswith(('/api/v1/generate', '/api/latest/generate')): kai_api_flag = True if basic_api_flag or kai_api_flag: @@ -170,7 +174,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): fullprompt = genparams.get('text', "") newprompt = fullprompt - recvtxt = "" if kai_api_flag: recvtxt = generate( @@ -207,7 +210,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): self.wfile.write(json.dumps(res).encode()) modelbusy = False return - self.send_response(404) self.end_headers()