Slight refactor of the Python code; credits to @LuxF3rre

This commit is contained in:
Concedo 2023-04-25 19:20:14 +08:00
parent 59fb174678
commit bff998f871

View file

@ -44,9 +44,14 @@ use_blas = False # if true, uses OpenBLAS for acceleration. libopenblas.dll must
use_clblast = False #uses CLBlast instead
use_noavx2 = False #uses openblas with no avx2 instructions
def getdirpath():
    """Return the directory containing this script.

    The script path is passed through ``os.path.realpath`` first, so the
    result is an absolute path with symbolic links resolved.
    """
    script_path = os.path.realpath(__file__)
    return os.path.dirname(script_path)
def file_exists(filename):
    """Return True if *filename* exists inside this script's directory.

    *filename* is joined onto the directory returned by ``getdirpath()``
    and checked with ``os.path.exists``, so directories count as existing
    too.
    """
    candidate = os.path.join(getdirpath(), filename)
    return os.path.exists(candidate)
def pick_existant_file(ntoption,nonntoption):
ntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), ntoption))
nonntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), nonntoption))
ntexist = file_exists(ntoption)
nonntexist = file_exists(nonntoption)
if os.name == 'nt':
if nonntexist and not ntexist:
return nonntoption
@ -80,10 +85,10 @@ def init_library():
libname = lib_default
print("Initializing dynamic library: " + libname)
dir_path = os.path.dirname(os.path.realpath(__file__))
dir_path = getdirpath()
#OpenBLAS should provide about a 2x speedup on prompt ingestion if compatible.
handle = ctypes.CDLL(os.path.join(dir_path, libname ))
handle = ctypes.CDLL(os.path.join(dir_path, libname))
handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
@ -108,7 +113,7 @@ def load_model(model_filename):
if args.useclblast:
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
inputs.clblast_info = clblastids
inputs.executable_path = (os.path.dirname(os.path.realpath(__file__))+"/").encode("UTF-8")
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
ret = handle.load_model(inputs)
return ret
@ -159,67 +164,46 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
    """Answer a GET request for the embedded UI or a read-only API endpoint.

    Trailing slashes are stripped from ``self.path`` before matching. The
    root URL (optionally carrying ``?params``) serves the embedded Kobold
    Lite page, or an HTML hint when ``self.embedded_kailite`` is None. The
    known ``/api/...`` endpoints are matched by suffix and answered with a
    small JSON document. Every matched route is sent as a 200 with a
    ``Content-Length`` header; anything else gets a 404 with a plain-text
    error message.
    """
    global maxctx, maxlen, friendlymodelname, KcppVersion
    self.path = self.path.rstrip('/')
    response_body = None  # stays None when no route matches -> 404 below

    # After rstrip, "/" has become "", and "/?..." is covered by startswith.
    if self.path == "" or self.path.startswith(('/?', '?')):  # it's possible for the root url to have ?params without /
        if self.embedded_kailite is None:
            response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
        else:
            response_body = self.embedded_kailite
    elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
        response_body = json.dumps({'result': friendlymodelname}).encode()
    elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
        response_body = json.dumps({"value": maxlen}).encode()
    elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
        response_body = json.dumps({"value": maxctx}).encode()
    elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
        response_body = json.dumps({"value": ""}).encode()
    elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
        response_body = json.dumps({"values": []}).encode()
    elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
        response_body = json.dumps({"result": "1.2.2"}).encode()
    elif self.path.endswith('/api/extra/version'):
        response_body = json.dumps({"result": "KoboldCpp", "version": KcppVersion}).encode()

    if response_body is None:
        self.send_response(404)
        self.end_headers()
        rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
        self.wfile.write(rp.encode())
    else:
        self.send_response(200)
        self.send_header('Content-Length', str(len(response_body)))
        self.end_headers()
        self.wfile.write(response_body)
def do_POST(self):
@ -366,6 +350,7 @@ def RunServerMultiThreaded(addr, port, embedded_kailite = None):
threadArr[i].stop()
sys.exit(0)
def main(args):
global use_blas, use_clblast, use_noavx2
global lib_default,lib_noavx2,lib_openblas,lib_openblas_noavx2,lib_clblast
@ -376,7 +361,7 @@ def main(args):
if args.noavx2:
use_noavx2 = True
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas_noavx2)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
if not file_exists(lib_openblas_noavx2) or (os.name=='nt' and not file_exists("libopenblas.dll")):
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
elif args.noblas:
print("Attempting to use non-avx2 compatibility library without OpenBLAS.")
@ -384,13 +369,13 @@ def main(args):
use_blas = True
print("Attempting to use non-avx2 compatibility library with OpenBLAS. A compatible libopenblas will be required.")
elif args.useclblast:
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_clblast)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "clblast.dll"))):
if not file_exists(lib_clblast) or (os.name=='nt' and not file_exists("clblast.dll")):
print("Warning: CLBlast library file not found. Non-BLAS library will be used.")
else:
print("Attempting to use CLBlast library for faster prompt ingestion. A compatible clblast will be required.")
use_clblast = True
else:
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
if not file_exists(lib_openblas) or (os.name=='nt' and not file_exists("libopenblas.dll")):
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
elif args.noblas:
print("Attempting to library without OpenBLAS.")