diff --git a/quantize.py b/quantize.py index 557e4e58f..12647940c 100755 --- a/quantize.py +++ b/quantize.py @@ -3,6 +3,7 @@ """Script to execute the "quantize" script on a given set of models.""" import argparse +import contextlib import glob import multiprocessing import os @@ -10,6 +11,27 @@ import subprocess import sys +@contextlib.contextmanager +def create_executor(threads=1): + if threads > 1: + pool = multiprocessing.Pool(threads) + + def executor(func, *args): + pool.apply_async(func, args) + + else: + + def executor(func, *args): + return func(*args) + + try: + yield executor + finally: + if threads > 1: + pool.close() + pool.join() + + def main(): """Update the quantize binary name depending on the platform and parse the command line arguments and execute the script. @@ -49,9 +71,12 @@ def main(): help='Specify the path to the "quantize" script.' ) parser.add_argument( - '-c', '--cores', dest='cores', - help='Specify the number of parallel quantization tasks', - default=min(4, os.cpu_count()) + '-t', + '--threads', + dest='threads', + type=int, + help='Specify the number of parallel quantization tasks [default=%(default)s]', + default=min(4, os.cpu_count()), ) args = parser.parse_args() @@ -65,39 +90,39 @@ def main(): ) sys.exit(1) - pool = multiprocessing.Pool(args.cores) - for model in args.models: - # The model is separated in various parts - # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) - f16_model_path_base = os.path.join( - args.models_path, model, "ggml-model-f16.bin" - ) + with create_executor(args.threads) as executor: + for model in args.models: + # The model is separated in various parts + # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) + f16_model_path_base = os.path.join( + args.models_path, model, "ggml-model-f16.bin" + ) - if not os.path.isfile(f16_model_path_base): - print(f'The file %s was not found' % f16_model_path_base) - sys.exit(1) - - f16_model_parts_paths = map( - lambda filename: os.path.join(f16_model_path_base, filename), - glob.glob(f"{f16_model_path_base}*") - ) - - for f16_model_part_path in f16_model_parts_paths: - if not os.path.isfile(f16_model_part_path): - print( - f"The f16 model {os.path.basename(f16_model_part_path)} " - f"was not found in {args.models_path}{os.path.sep}{model}" - ". If you want to use it from another location, set the " - "--models-path argument from the command line." - ) + if not os.path.isfile(f16_model_path_base): + print(f'The file %s was not found' % f16_model_path_base) sys.exit(1) - pool.apply_async( - __run_quantize_script, - (args.quantize_script_path, f16_model_part_path, args.remove_f16), + f16_model_parts_paths = map( + lambda filename: os.path.join(f16_model_path_base, filename), + glob.glob(f"{f16_model_path_base}*"), ) - pool.close() - pool.join() + + for f16_model_part_path in f16_model_parts_paths: + if not os.path.isfile(f16_model_part_path): + print( + f"The f16 model {os.path.basename(f16_model_part_path)} " + f"was not found in {args.models_path}{os.path.sep}{model}" + ". If you want to use it from another location, set the " + "--models-path argument from the command line." + ) + sys.exit(1) + + executor( + __run_quantize_script, + args.quantize_script_path, + f16_model_part_path, + args.remove_f16, + ) def __run_quantize_script(script_path, f16_model_part_path, remove_f16):