globalize args

This commit is contained in:
Concedo 2023-08-10 16:30:02 +08:00
parent a07e6dd3ad
commit c5f5209d37

View file

@ -96,7 +96,7 @@ lib_cublas = pick_existant_file("koboldcpp_cublas.dll","koboldcpp_cublas.so")
def init_library():
global handle
global handle, args
global lib_default,lib_failsafe,lib_openblas,lib_noavx2,lib_clblast,lib_cublas
libname = ""
@ -174,6 +174,7 @@ def init_library():
handle.get_pending_output.restype = ctypes.c_char_p
def load_model(model_filename):
global args
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.batch_size = 8
@ -232,7 +233,7 @@ def load_model(model_filename):
return ret
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], stream_sse=False):
global maxctx
global maxctx, args
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8")
@ -308,6 +309,7 @@ showdebug = True
showsamplerwarning = True
showmaxctxwarning = True
exitcounter = 0
args = None #global args
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
sys_version = ""
@ -1562,7 +1564,9 @@ def run_horde_worker(args, api_key, worker_name):
time.sleep(2)
sys.exit(2)
def main(args):
def main(launch_args,start_server=True):
global args
args = launch_args
embedded_kailite = None
if not args.model_param:
args.model_param = args.model
@ -1692,8 +1696,13 @@ def main(args):
horde_thread.daemon = True
horde_thread.start()
print(f"Please connect to custom endpoint at {epurl}")
asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
if start_server:
print(f"Please connect to custom endpoint at {epurl}")
asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
else:
print(f"Server was not started, main function complete. Idling.")
# while True:
# time.sleep(5)
if __name__ == '__main__':
print("***\nWelcome to KoboldCpp - Version " + KcppVersion) # just update version manually
@ -1738,6 +1747,4 @@ if __name__ == '__main__':
parser.add_argument("--gpulayers", help="Set number of layers to offload to GPU when using GPU. Requires GPU.",metavar=('[GPU layers]'), type=int, default=0)
parser.add_argument("--tensor_split", help="For CUDA with ALL GPU set only, ratio to split tensors across multiple GPUs, space-separated list of proportions, e.g. 7 3", metavar=('[Ratios]'), type=float, nargs='+')
args = parser.parse_args()
main(args)
main(parser.parse_args(),start_server=True)