diff --git a/llama_for_kobold.py b/llama_for_kobold.py
index 2e8369eb4..adf86a791 100644
--- a/llama_for_kobold.py
+++ b/llama_for_kobold.py
@@ -79,6 +79,7 @@ maxlen = 128
 modelbusy = False
 port = 5001
 last_context = ""
+embedded_kailite = None
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
@@ -87,8 +88,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     def do_GET(self):
         if self.path=="/":
-            self.path = "/klite.embd"
-            return http.server.SimpleHTTPRequestHandler.do_GET(self)
+            if embedded_kailite is None:
+                self.send_response(200)
+                self.end_headers()
+                self.wfile.write(b'Embedded Kobold Lite is not found.\nYou will have to connect via the main KoboldAI client, or use this URL to connect.')
+            else:
+                self.send_response(200)
+                self.end_headers()
+                self.wfile.write(embedded_kailite)
+            return
 
         if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/') or self.path.endswith('/api/v1/model') or self.path.endswith('/api/latest/model'):
             self.send_response(200)
@@ -184,10 +192,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_header('Access-Control-Allow-Origin', '*')
         self.send_header('Access-Control-Allow-Methods', '*')
         self.send_header('Access-Control-Allow-Headers', '*')
-        if "/klite.embd" in self.path:
-            self.send_header('Content-type', 'text/html')
-        else:
+        if "/api" in self.path:
             self.send_header('Content-type', 'application/json')
+        else:
+            self.send_header('Content-type', 'text/html')
+
         return super(ServerRequestHandler, self).end_headers()
 
 
@@ -263,6 +272,14 @@ if __name__ == '__main__':
         friendlymodelname = "concedo/llamacpp"
 
     if loadok:
+        try:
+            basepath = os.path.abspath(os.path.dirname(__file__))
+            with open(basepath+"/klite.embd", mode="rb") as emb_kai:
+                embedded_kailite = emb_kai.read()
+                print("Embedded Kobold Lite loaded.")
+        except OSError:  # a missing or unreadable klite.embd is non-fatal; the API still works
+            print("Could not find Kobold Lite. Embedded Kobold Lite will not be available.")
+
         print("Starting Kobold HTTP Server on port " + str(port))
         print("Please connect to custom endpoint at http://localhost:"+str(port))
         RunServerMultiThreaded(port)