apply PR suggestions

- rename -c/--cores option to -t/--threads
- only use multiprocessing module when --threads > 1
This commit is contained in:
Tristan Carel 2023-03-28 22:36:56 +02:00
parent 08121f3aa8
commit 1041ddb2cd

View file

@ -3,6 +3,7 @@
"""Script to execute the "quantize" script on a given set of models.""" """Script to execute the "quantize" script on a given set of models."""
import argparse import argparse
import contextlib
import glob import glob
import multiprocessing import multiprocessing
import os import os
@ -10,6 +11,27 @@ import subprocess
import sys import sys
@contextlib.contextmanager
def create_executor(threads=1):
if threads > 1:
pool = multiprocessing.Pool(threads)
def executor(func, *args):
pool.apply_async(func, args)
else:
def executor(func, *args):
return func(*args)
try:
yield executor
finally:
if threads > 1:
pool.close()
pool.join()
def main(): def main():
"""Update the quantize binary name depending on the platform and parse """Update the quantize binary name depending on the platform and parse
the command line arguments and execute the script. the command line arguments and execute the script.
@ -49,9 +71,12 @@ def main():
help='Specify the path to the "quantize" script.' help='Specify the path to the "quantize" script.'
) )
parser.add_argument( parser.add_argument(
'-c', '--cores', dest='cores', '-t',
help='Specify the number of parallel quantization tasks', '--threads',
default=min(4, os.cpu_count()) dest='threads',
type=int,
help='Specify the number of parallel quantization tasks [default=%(default)s]',
default=min(4, os.cpu_count()),
) )
args = parser.parse_args() args = parser.parse_args()
@ -65,7 +90,7 @@ def main():
) )
sys.exit(1) sys.exit(1)
pool = multiprocessing.Pool(args.cores) with create_executor(args.threads) as executor:
for model in args.models: for model in args.models:
# The model is separated in various parts # The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
@ -79,7 +104,7 @@ def main():
f16_model_parts_paths = map( f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename), lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*") glob.glob(f"{f16_model_path_base}*"),
) )
for f16_model_part_path in f16_model_parts_paths: for f16_model_part_path in f16_model_parts_paths:
@ -92,12 +117,12 @@ def main():
) )
sys.exit(1) sys.exit(1)
pool.apply_async( executor(
__run_quantize_script, __run_quantize_script,
(args.quantize_script_path, f16_model_part_path, args.remove_f16), args.quantize_script_path,
f16_model_part_path,
args.remove_f16,
) )
pool.close()
pool.join()
def __run_quantize_script(script_path, f16_model_part_path, remove_f16): def __run_quantize_script(script_path, f16_model_part_path, remove_f16):