apply PR suggestions

- rename -c/--cores option to -t/--threads
- only use multiprocessing module when --threads > 1
This commit is contained in:
Tristan Carel 2023-03-28 22:36:56 +02:00
parent 08121f3aa8
commit 1041ddb2cd

View file

@ -3,6 +3,7 @@
"""Script to execute the "quantize" script on a given set of models.""" """Script to execute the "quantize" script on a given set of models."""
import argparse import argparse
import contextlib
import glob import glob
import multiprocessing import multiprocessing
import os import os
@ -10,6 +11,27 @@ import subprocess
import sys import sys
@contextlib.contextmanager
def create_executor(threads=1):
if threads > 1:
pool = multiprocessing.Pool(threads)
def executor(func, *args):
pool.apply_async(func, args)
else:
def executor(func, *args):
return func(*args)
try:
yield executor
finally:
if threads > 1:
pool.close()
pool.join()
def main(): def main():
"""Update the quantize binary name depending on the platform and parse """Update the quantize binary name depending on the platform and parse
the command line arguments and execute the script. the command line arguments and execute the script.
@ -49,9 +71,12 @@ def main():
help='Specify the path to the "quantize" script.' help='Specify the path to the "quantize" script.'
) )
parser.add_argument( parser.add_argument(
'-c', '--cores', dest='cores', '-t',
help='Specify the number of parallel quantization tasks', '--threads',
default=min(4, os.cpu_count()) dest='threads',
type=int,
help='Specify the number of parallel quantization tasks [default=%(default)s]',
default=min(4, os.cpu_count()),
) )
args = parser.parse_args() args = parser.parse_args()
@ -65,39 +90,39 @@ def main():
) )
sys.exit(1) sys.exit(1)
pool = multiprocessing.Pool(args.cores) with create_executor(args.threads) as executor:
for model in args.models: for model in args.models:
# The model is separated in various parts # The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
f16_model_path_base = os.path.join( f16_model_path_base = os.path.join(
args.models_path, model, "ggml-model-f16.bin" args.models_path, model, "ggml-model-f16.bin"
) )
if not os.path.isfile(f16_model_path_base): if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base) print(f'The file %s was not found' % f16_model_path_base)
sys.exit(1)
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
)
for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
print(
f"The f16 model {os.path.basename(f16_model_part_path)} "
f"was not found in {args.models_path}{os.path.sep}{model}"
". If you want to use it from another location, set the "
"--models-path argument from the command line."
)
sys.exit(1) sys.exit(1)
pool.apply_async( f16_model_parts_paths = map(
__run_quantize_script, lambda filename: os.path.join(f16_model_path_base, filename),
(args.quantize_script_path, f16_model_part_path, args.remove_f16), glob.glob(f"{f16_model_path_base}*"),
) )
pool.close()
pool.join() for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
print(
f"The f16 model {os.path.basename(f16_model_part_path)} "
f"was not found in {args.models_path}{os.path.sep}{model}"
". If you want to use it from another location, set the "
"--models-path argument from the command line."
)
sys.exit(1)
executor(
__run_quantize_script,
args.quantize_script_path,
f16_model_part_path,
args.remove_f16,
)
def __run_quantize_script(script_path, f16_model_part_path, remove_f16): def __run_quantize_script(script_path, f16_model_part_path, remove_f16):