globalize args
parent a07e6dd3ad
commit c5f5209d37
1 changed file with 15 additions and 8 deletions
koboldcpp.py | 23 +++++++++++++++--------

--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -96,7 +96,7 @@ lib_cublas = pick_existant_file("koboldcpp_cublas.dll","koboldcpp_cublas.so")
 
 
 def init_library():
-    global handle
+    global handle, args
     global lib_default,lib_failsafe,lib_openblas,lib_noavx2,lib_clblast,lib_cublas
 
     libname = ""
@@ -174,6 +174,7 @@ def init_library():
     handle.get_pending_output.restype = ctypes.c_char_p
 
 def load_model(model_filename):
+    global args
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
     inputs.batch_size = 8
@@ -232,7 +233,7 @@ def load_model(model_filename):
     return ret
 
 def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], stream_sse=False):
-    global maxctx
+    global maxctx, args
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -308,6 +309,7 @@ showdebug = True
 showsamplerwarning = True
 showmaxctxwarning = True
 exitcounter = 0
+args = None #global args
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
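
Note: the hunks above hinge on Python's module-level globals. `args` is created once at module scope (the `args = None #global args` line) and every function that touches it declares `global args`. Below is a minimal sketch of that pattern; the function and flag names are placeholders for illustration, not koboldcpp's real API.

# Minimal sketch of the "globalize args" pattern; placeholder names only.
import argparse

args = None  # module-level slot, mirrors "args = None #global args" above

def report():
    # Reading a module-level name needs no declaration, but declaring
    # "global args" in readers (as this commit does) documents the dependency.
    global args
    print(f"verbose: {args.verbose}")

def main(launch_args):
    global args          # required here: this function assigns to the global
    args = launch_args
    report()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    main(parser.parse_args())
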
@@ -1562,7 +1564,9 @@ def run_horde_worker(args, api_key, worker_name):
         time.sleep(2)
         sys.exit(2)
 
-def main(args):
+def main(launch_args,start_server=True):
+    global args
+    args = launch_args
     embedded_kailite = None
     if not args.model_param:
         args.model_param = args.model
@@ -1692,8 +1696,13 @@ def main(args):
         horde_thread.daemon = True
         horde_thread.start()
 
-    print(f"Please connect to custom endpoint at {epurl}")
-    asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
+    if start_server:
+        print(f"Please connect to custom endpoint at {epurl}")
+        asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
+    else:
+        print(f"Server was not started, main function complete. Idling.")
+        # while True:
+        #     time.sleep(5)
 
 if __name__ == '__main__':
     print("***\nWelcome to KoboldCpp - Version " + KcppVersion) # just update version manually
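
Note: the hunk above is the behavioural core of the change. With `start_server=True` (the default) `main()` still blocks inside `asyncio.run(...)`; with `start_server=False` it prints the idle message and returns. Below is a generic, self-contained sketch of that toggle; the handler and names are illustrative assumptions, not koboldcpp's server code.

# Illustrative sketch of the start_server toggle; not koboldcpp's server code.
import asyncio
import types

async def handle_client(reader, writer):
    writer.close()  # placeholder handler: accept the connection and close it

async def run_server(host, port):
    server = await asyncio.start_server(handle_client, host, port)
    async with server:
        await server.serve_forever()  # blocks until the task is cancelled

def main(launch_args, start_server=True):
    if start_server:
        print(f"Please connect to custom endpoint at http://{launch_args.host}:{launch_args.port}")
        asyncio.run(run_server(launch_args.host, launch_args.port))  # blocks here
    else:
        print("Server was not started, main function complete. Idling.")
        # the caller keeps the process alive and drives the model directly

if __name__ == "__main__":
    main(types.SimpleNamespace(host="localhost", port=5001), start_server=False)
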
@@ -1738,6 +1747,4 @@ if __name__ == '__main__':
     parser.add_argument("--gpulayers", help="Set number of layers to offload to GPU when using GPU. Requires GPU.",metavar=('[GPU layers]'), type=int, default=0)
     parser.add_argument("--tensor_split", help="For CUDA with ALL GPU set only, ratio to split tensors across multiple GPUs, space-separated list of proportions, e.g. 7 3", metavar=('[Ratios]'), type=float, nargs='+')
 
-    args = parser.parse_args()
-
-    main(args)
+    main(parser.parse_args(),start_server=True)
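
Note: the final hunk only inlines the call; `parse_args()` is still invoked exactly once and `start_server=True` matches the parameter's default, so command-line behaviour is unchanged. A toy before/after comparison follows, using a stub parser and `main`, not koboldcpp's.

# Toy comparison of the old and new __main__ styles; stub parser and main only.
import argparse

def main(launch_args, start_server=True):
    print(launch_args, start_server)

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=5001)

# old style: bind the namespace to a name first
args = parser.parse_args([])
main(args)

# new style: pass the parsed namespace straight through; the output is identical
main(parser.parse_args([]), start_server=True)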