diff --git a/quantize.py b/quantize.py
old mode 100644
new mode 100755
index 641df8dda..557e4e58f
--- a/quantize.py
+++ b/quantize.py
@@ -2,11 +2,12 @@
 
 """Script to execute the "quantize" script on a given set of models."""
 
-import subprocess
 import argparse
 import glob
-import sys
+import multiprocessing
 import os
+import subprocess
+import sys
 
 
 def main():
@@ -47,14 +48,11 @@ def main():
         default=os.path.join(os.getcwd(), quantize_script_binary),
         help='Specify the path to the "quantize" script.'
     )
-
-    # TODO: Revise this code
-    # parser.add_argument(
-    #     '-t', '--threads', dest='threads', type='int',
-    #     default=os.cpu_count(),
-    #     help='Specify the number of threads to use to quantize many models at '
-    #     'once. Defaults to os.cpu_count().'
-    # )
+    parser.add_argument(
+        '-c', '--cores', dest='cores', type=int,
+        help='Specify the number of parallel quantization tasks',
+        default=min(4, os.cpu_count())
+    )
 
     args = parser.parse_args()
     args.models_path = os.path.abspath(args.models_path)
@@ -67,6 +65,7 @@ def main():
         )
         sys.exit(1)
 
+    pool = multiprocessing.Pool(args.cores)
     for model in args.models:
         # The model is separated in various parts
         # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
@@ -93,18 +92,15 @@ def main():
                 )
                 sys.exit(1)
 
-            __run_quantize_script(
-                args.quantize_script_path, f16_model_part_path
+            pool.apply_async(
+                __run_quantize_script,
+                (args.quantize_script_path, f16_model_part_path, args.remove_f16),
             )
-
-            if args.remove_f16:
-                os.remove(f16_model_part_path)
+    pool.close()
+    pool.join()
 
 
-# This was extracted to a top-level function for parallelization, if
-# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
-
-def __run_quantize_script(script_path, f16_model_part_path):
+def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
     """Run the quantize script specifying the path to it and the path to the
     f16 model to quantize.
     """
@@ -114,6 +110,8 @@ def __run_quantize_script(script_path, f16_model_part_path):
         [script_path, f16_model_part_path, new_quantized_model_path, "2"],
         check=True
     )
+    if remove_f16:
+        os.remove(f16_model_part_path)
 
 
 if __name__ == "__main__":
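
For context, below is a minimal, self-contained sketch of the multiprocessing.Pool pattern this patch adopts: tasks are queued with apply_async, then close() stops further submissions and join() blocks until every worker finishes. The fake_quantize worker and the hard-coded part names are illustrative placeholders only, not part of the patch.

import multiprocessing
import os


def fake_quantize(path):
    # Illustrative stand-in for __run_quantize_script; the real worker
    # shells out to the quantize binary via subprocess.run.
    print("quantizing %s in process %d" % (path, os.getpid()))


if __name__ == "__main__":
    parts = ["ggml-model-f16.bin.0", "ggml-model-f16.bin.1"]
    pool = multiprocessing.Pool(min(4, os.cpu_count()))
    for part in parts:
        # apply_async returns immediately, so the parts run in parallel.
        pool.apply_async(fake_quantize, (part,))
    pool.close()  # no more tasks will be submitted
    pool.join()   # wait for all queued tasks to finish

One design note: apply_async stores a worker's exception in its AsyncResult and only re-raises it when get() is called (or via error_callback), so a failed quantization in the patched loop above would go unreported unless those results are checked.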