apply PR suggestions

- rename -c/--cores option to -t/--threads
- only use multiprocessing module when --threads > 1
This commit is contained in:
Tristan Carel 2023-03-28 22:36:56 +02:00
parent 08121f3aa8
commit 1041ddb2cd

View file

@ -3,6 +3,7 @@
"""Script to execute the "quantize" script on a given set of models."""
import argparse
import contextlib
import glob
import multiprocessing
import os
@ -10,6 +11,27 @@ import subprocess
import sys
@contextlib.contextmanager
def create_executor(threads=1):
if threads > 1:
pool = multiprocessing.Pool(threads)
def executor(func, *args):
pool.apply_async(func, args)
else:
def executor(func, *args):
return func(*args)
try:
yield executor
finally:
if threads > 1:
pool.close()
pool.join()
def main():
"""Update the quantize binary name depending on the platform and parse
the command line arguments and execute the script.
@ -49,9 +71,12 @@ def main():
help='Specify the path to the "quantize" script.'
)
parser.add_argument(
'-c', '--cores', dest='cores',
help='Specify the number of parallel quantization tasks',
default=min(4, os.cpu_count())
'-t',
'--threads',
dest='threads',
type=int,
help='Specify the number of parallel quantization tasks [default=%(default)s]',
default=min(4, os.cpu_count()),
)
args = parser.parse_args()
@ -65,7 +90,7 @@ def main():
)
sys.exit(1)
pool = multiprocessing.Pool(args.cores)
with create_executor(args.threads) as executor:
for model in args.models:
# The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
@ -79,7 +104,7 @@ def main():
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
glob.glob(f"{f16_model_path_base}*"),
)
for f16_model_part_path in f16_model_parts_paths:
@ -92,12 +117,12 @@ def main():
)
sys.exit(1)
pool.apply_async(
executor(
__run_quantize_script,
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
args.quantize_script_path,
f16_model_part_path,
args.remove_f16,
)
pool.close()
pool.join()
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):