From bff998f8712700c3a791ae71bf36e8495a88657a Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Tue, 25 Apr 2023 19:20:14 +0800
Subject: [PATCH] Slight refactor of the python code: credits to @LuxF3rre
---
koboldcpp.py | 105 ++++++++++++++++++++++-----------------------------
1 file changed, 45 insertions(+), 60 deletions(-)
diff --git a/koboldcpp.py b/koboldcpp.py
index 92f99e7d4..8310b7566 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -44,9 +44,14 @@ use_blas = False # if true, uses OpenBLAS for acceleration. libopenblas.dll must
use_clblast = False #uses CLBlast instead
use_noavx2 = False #uses openblas with no avx2 instructions
+def getdirpath():
+ return os.path.dirname(os.path.realpath(__file__))
+def file_exists(filename):
+ return os.path.exists(os.path.join(getdirpath(), filename))
+
def pick_existant_file(ntoption,nonntoption):
- ntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), ntoption))
- nonntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), nonntoption))
+ ntexist = file_exists(ntoption)
+ nonntexist = file_exists(nonntoption)
if os.name == 'nt':
if nonntexist and not ntexist:
return nonntoption
@@ -80,10 +85,10 @@ def init_library():
libname = lib_default
print("Initializing dynamic library: " + libname)
- dir_path = os.path.dirname(os.path.realpath(__file__))
+ dir_path = getdirpath()
#OpenBLAS should provide about a 2x speedup on prompt ingestion if compatible.
- handle = ctypes.CDLL(os.path.join(dir_path, libname ))
+ handle = ctypes.CDLL(os.path.join(dir_path, libname))
handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
@@ -108,7 +113,7 @@ def load_model(model_filename):
if args.useclblast:
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
inputs.clblast_info = clblastids
- inputs.executable_path = (os.path.dirname(os.path.realpath(__file__))+"/").encode("UTF-8")
+ inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
ret = handle.load_model(inputs)
return ret
@@ -159,67 +164,46 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
global maxctx, maxlen, friendlymodelname, KcppVersion
- if self.path in ["/", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
- response_body = ""
+ self.path = self.path.rstrip('/')
+ response_body = None
+
+ if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
if self.embedded_kailite is None:
response_body = (f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.").encode()
else:
response_body = self.embedded_kailite
+ elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
+ response_body = (json.dumps({'result': friendlymodelname }).encode())
+
+ elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
+ response_body = (json.dumps({"value": maxlen}).encode())
+
+ elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
+ response_body = (json.dumps({"value": maxctx}).encode())
+
+ elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
+ response_body = (json.dumps({"value":""}).encode())
+
+ elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
+ response_body = (json.dumps({"values": []}).encode())
+
+ elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
+ response_body = (json.dumps({"result":"1.2.2"}).encode())
+
+ elif self.path.endswith(('/api/extra/version')):
+ response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
+
+ if response_body is None:
+ self.send_response(404)
+ self.end_headers()
+ rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
+ self.wfile.write(rp.encode())
+ else:
self.send_response(200)
self.send_header('Content-Length', str(len(response_body)))
self.end_headers()
self.wfile.write(response_body)
- return
-
- self.path = self.path.rstrip('/')
- if self.path.endswith(('/api/v1/model', '/api/latest/model')):
- self.send_response(200)
- self.end_headers()
- result = {'result': friendlymodelname }
- self.wfile.write(json.dumps(result).encode())
- return
-
- if self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"value": maxlen}).encode())
- return
-
- if self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"value": maxctx}).encode())
- return
-
- if self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"value":""}).encode())
- return
-
- if self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"values": []}).encode())
- return
-
- if self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"result":"1.2.2"}).encode())
- return
-
- if self.path.endswith(('/api/extra/version')):
- self.send_response(200)
- self.end_headers()
- self.wfile.write(json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
- return
-
- self.send_response(404)
- self.end_headers()
- rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
- self.wfile.write(rp.encode())
return
def do_POST(self):
@@ -366,6 +350,7 @@ def RunServerMultiThreaded(addr, port, embedded_kailite = None):
threadArr[i].stop()
sys.exit(0)
+
def main(args):
global use_blas, use_clblast, use_noavx2
global lib_default,lib_noavx2,lib_openblas,lib_openblas_noavx2,lib_clblast
@@ -376,7 +361,7 @@ def main(args):
if args.noavx2:
use_noavx2 = True
- if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas_noavx2)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
+ if not file_exists(lib_openblas_noavx2) or (os.name=='nt' and not file_exists("libopenblas.dll")):
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
elif args.noblas:
print("Attempting to use non-avx2 compatibility library without OpenBLAS.")
@@ -384,13 +369,13 @@ def main(args):
use_blas = True
print("Attempting to use non-avx2 compatibility library with OpenBLAS. A compatible libopenblas will be required.")
elif args.useclblast:
- if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_clblast)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "clblast.dll"))):
+ if not file_exists(lib_clblast) or (os.name=='nt' and not file_exists("clblast.dll")):
print("Warning: CLBlast library file not found. Non-BLAS library will be used.")
else:
print("Attempting to use CLBlast library for faster prompt ingestion. A compatible clblast will be required.")
use_clblast = True
else:
- if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
+ if not file_exists(lib_openblas) or (os.name=='nt' and not file_exists("libopenblas.dll")):
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
elif args.noblas:
print("Attempting to library without OpenBLAS.")