Slight refactor of the python code: credits to @LuxF3rre
This commit is contained in:
parent
59fb174678
commit
bff998f871
1 changed file with 45 additions and 60 deletions
105
koboldcpp.py
105
koboldcpp.py
|
@ -44,9 +44,14 @@ use_blas = False # if true, uses OpenBLAS for acceleration. libopenblas.dll must
|
|||
use_clblast = False #uses CLBlast instead
|
||||
use_noavx2 = False #uses openblas with no avx2 instructions
|
||||
|
||||
def getdirpath():
    """Return the directory that contains this script.

    The script path is passed through ``os.path.realpath`` first, so the
    result is an absolute path with symbolic links resolved.
    """
    script_path = os.path.realpath(__file__)
    return os.path.dirname(script_path)
|
||||
def file_exists(filename):
    """Return True if *filename* exists in the directory holding this script.

    ``filename`` is joined onto the directory returned by ``getdirpath()``
    and the resulting path is checked with ``os.path.exists``.
    """
    candidate = os.path.join(getdirpath(), filename)
    return os.path.exists(candidate)
|
||||
|
||||
def pick_existant_file(ntoption,nonntoption):
|
||||
ntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), ntoption))
|
||||
nonntexist = os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), nonntoption))
|
||||
ntexist = file_exists(ntoption)
|
||||
nonntexist = file_exists(nonntoption)
|
||||
if os.name == 'nt':
|
||||
if nonntexist and not ntexist:
|
||||
return nonntoption
|
||||
|
@ -80,10 +85,10 @@ def init_library():
|
|||
libname = lib_default
|
||||
|
||||
print("Initializing dynamic library: " + libname)
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
dir_path = getdirpath()
|
||||
|
||||
#OpenBLAS should provide about a 2x speedup on prompt ingestion if compatible.
|
||||
handle = ctypes.CDLL(os.path.join(dir_path, libname ))
|
||||
handle = ctypes.CDLL(os.path.join(dir_path, libname))
|
||||
|
||||
handle.load_model.argtypes = [load_model_inputs]
|
||||
handle.load_model.restype = ctypes.c_bool
|
||||
|
@ -108,7 +113,7 @@ def load_model(model_filename):
|
|||
if args.useclblast:
|
||||
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
|
||||
inputs.clblast_info = clblastids
|
||||
inputs.executable_path = (os.path.dirname(os.path.realpath(__file__))+"/").encode("UTF-8")
|
||||
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
|
||||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
|
@ -159,67 +164,46 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
|
||||
def do_GET(self):
    """Handle GET requests for the embedded UI and the read-only API endpoints.

    Trailing slashes are stripped from ``self.path`` before matching.  The
    root URL (optionally carrying ``?params``) returns the embedded Kobold
    Lite page, or a short HTML notice with a link when ``embedded_kailite``
    is None.  The known ``/api/...`` endpoints return small JSON documents.
    Any other path produces a 404 with a plain-text explanation.
    """
    global maxctx, maxlen, friendlymodelname, KcppVersion

    # Normalise so that e.g. "/api/v1/model/" matches "/api/v1/model".
    self.path = self.path.rstrip('/')
    response_body = None

    # It's possible for the root url to have ?params without a leading "/".
    if self.path in ["", "/?"] or self.path.startswith(('/?', '?')):
        if self.embedded_kailite is None:
            response_body = (f"Embedded Kobold Lite is not found.<br>You will have to connect via the main KoboldAI client, or <a href='https://lite.koboldai.net?local=1&port={self.port}'>use this URL</a> to connect.").encode()
        else:
            response_body = self.embedded_kailite
    else:
        # (accepted path suffixes, JSON payload) pairs, checked in order;
        # the first suffix match wins.
        json_endpoints = (
            (('/api/v1/model', '/api/latest/model'),
             {'result': friendlymodelname}),
            (('/api/v1/config/max_length', '/api/latest/config/max_length'),
             {"value": maxlen}),
            (('/api/v1/config/max_context_length', '/api/latest/config/max_context_length'),
             {"value": maxctx}),
            (('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt'),
             {"value": ""}),
            (('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list'),
             {"values": []}),
            (('/api/v1/info/version', '/api/latest/info/version'),
             {"result": "1.2.2"}),
            (('/api/extra/version',),
             {"result": "KoboldCpp", "version": KcppVersion}),
        )
        for suffixes, payload in json_endpoints:
            if self.path.endswith(suffixes):
                response_body = json.dumps(payload).encode()
                break

    if response_body is None:
        status = 404
        response_body = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'.encode()
    else:
        status = 200

    # Always announce the body size, for error responses as well, so the
    # client can tell where the body ends.
    self.send_response(status)
    self.send_header('Content-Length', str(len(response_body)))
    self.end_headers()
    self.wfile.write(response_body)
    return
|
||||
|
||||
def do_POST(self):
|
||||
|
@ -366,6 +350,7 @@ def RunServerMultiThreaded(addr, port, embedded_kailite = None):
|
|||
threadArr[i].stop()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(args):
|
||||
global use_blas, use_clblast, use_noavx2
|
||||
global lib_default,lib_noavx2,lib_openblas,lib_openblas_noavx2,lib_clblast
|
||||
|
@ -376,7 +361,7 @@ def main(args):
|
|||
|
||||
if args.noavx2:
|
||||
use_noavx2 = True
|
||||
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas_noavx2)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
|
||||
if not file_exists(lib_openblas_noavx2) or (os.name=='nt' and not file_exists("libopenblas.dll")):
|
||||
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
|
||||
elif args.noblas:
|
||||
print("Attempting to use non-avx2 compatibility library without OpenBLAS.")
|
||||
|
@ -384,13 +369,13 @@ def main(args):
|
|||
use_blas = True
|
||||
print("Attempting to use non-avx2 compatibility library with OpenBLAS. A compatible libopenblas will be required.")
|
||||
elif args.useclblast:
|
||||
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_clblast)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "clblast.dll"))):
|
||||
if not file_exists(lib_clblast) or (os.name=='nt' and not file_exists("clblast.dll")):
|
||||
print("Warning: CLBlast library file not found. Non-BLAS library will be used.")
|
||||
else:
|
||||
print("Attempting to use CLBlast library for faster prompt ingestion. A compatible clblast will be required.")
|
||||
use_clblast = True
|
||||
else:
|
||||
if not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), lib_openblas)) or (os.name=='nt' and not os.path.exists(os.path.join(os.path.dirname(os.path.realpath(__file__)), "libopenblas.dll"))):
|
||||
if not file_exists(lib_openblas) or (os.name=='nt' and not file_exists("libopenblas.dll")):
|
||||
print("Warning: OpenBLAS library file not found. Non-BLAS library will be used.")
|
||||
elif args.noblas:
|
||||
print("Attempting to library without OpenBLAS.")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue