diff --git a/llama_for_kobold.py b/llama_for_kobold.py
index 4dd33d926..1c8b17a3e 100644
--- a/llama_for_kobold.py
+++ b/llama_for_kobold.py
@@ -87,39 +87,42 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
super().__init__(*args, **kwargs)
def do_GET(self):
- if self.path=="/" or self.path.startswith('/?') or self.path.startswith('?'):
+ global maxctx, maxlen, friendlymodelname
+ if self.path in ["/", "/?"] or self.path.startswith('/?'):
if self.embedded_kailite is None:
- self.send_response(200)
- self.end_headers()
- self.wfile.write(b'Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.')
+ response_body = (
+ b"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or "
+ b"use this URL to connect.").format(self.port).encode()
else:
+ response_body = self.embedded_kailite
+
self.send_response(200)
+ self.send_header('Content-Length', str(len(response_body)))
self.end_headers()
- self.wfile.write(self.embedded_kailite)
+ self.wfile.write(response_body)
return
- if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/') or self.path.endswith('/api/v1/model') or self.path.endswith('/api/latest/model'):
+ self.path = self.path.rstrip('/')
+ if self.path.endswith(('/api/v1/model', '/api/latest/model')):
self.send_response(200)
self.end_headers()
- global friendlymodelname
- self.wfile.write(json.dumps({"result": friendlymodelname }).encode())
+ result = {'result': friendlymodelname }
+ self.wfile.write(json.dumps(result).encode())
return
- if self.path.endswith('/api/v1/config/max_length/') or self.path.endswith('/api/latest/config/max_length/') or self.path.endswith('/api/v1/config/max_length') or self.path.endswith('/api/latest/config/max_length'):
+ if self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
self.send_response(200)
self.end_headers()
- global maxlen
- self.wfile.write(json.dumps({"value":maxlen}).encode())
+ self.wfile.write(json.dumps({"value": maxlen}).encode())
return
- if self.path.endswith('/api/v1/config/max_context_length/') or self.path.endswith('/api/latest/config/max_context_length/') or self.path.endswith('/api/v1/config/max_context_length') or self.path.endswith('/api/latest/config/max_context_length'):
+ if self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
self.send_response(200)
self.end_headers()
- global maxctx
- self.wfile.write(json.dumps({"value":maxctx}).encode())
+ self.wfile.write(json.dumps({"value": maxctx}).encode())
return
- if self.path.endswith('/api/v1/config/soft_prompt') or self.path.endswith('/api/v1/config/soft_prompt/') or self.path.endswith('/api/latest/config/soft_prompt') or self.path.endswith('/api/latest/config/soft_prompt/'):
+ if self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
self.send_response(200)
self.end_headers()
self.wfile.write(json.dumps({"value":""}).encode())
@@ -130,12 +133,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
self.wfile.write(rp.encode())
return
-
def do_POST(self):
global modelbusy
content_length = int(self.headers['Content-Length'])
body = self.rfile.read(content_length)
+ basic_api_flag = False
+ kai_api_flag = False
+ self.path = self.path.rstrip('/')
if modelbusy:
self.send_response(503)
@@ -146,11 +151,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
}}).encode())
return
- basic_api_flag = False
- kai_api_flag = False
- if self.path.endswith('/request') or self.path.endswith('/request'):
+ if self.path.endswith('/request'):
basic_api_flag = True
- if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/') or self.path.endswith('/api/v1/generate') or self.path.endswith('/api/latest/generate'):
+
+ if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
kai_api_flag = True
if basic_api_flag or kai_api_flag:
@@ -170,7 +174,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
fullprompt = genparams.get('text', "")
newprompt = fullprompt
-
recvtxt = ""
if kai_api_flag:
recvtxt = generate(
@@ -207,7 +210,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
self.wfile.write(json.dumps(res).encode())
modelbusy = False
return
-
self.send_response(404)
self.end_headers()