apply PR suggestions

- rename -c/--cores option to -t/--threads
- only use multiprocessing module when --threads > 1
This commit is contained in:
Tristan Carel 2023-03-28 22:36:56 +02:00
parent 08121f3aa8
commit 1041ddb2cd

View file

@ -3,6 +3,7 @@
"""Script to execute the "quantize" script on a given set of models."""
import argparse
import contextlib
import glob
import multiprocessing
import os
@ -10,6 +11,27 @@ import subprocess
import sys
@contextlib.contextmanager
def create_executor(threads=1):
if threads > 1:
pool = multiprocessing.Pool(threads)
def executor(func, *args):
pool.apply_async(func, args)
else:
def executor(func, *args):
return func(*args)
try:
yield executor
finally:
if threads > 1:
pool.close()
pool.join()
def main():
"""Update the quantize binary name depending on the platform and parse
the command line arguments and execute the script.
@ -49,9 +71,12 @@ def main():
help='Specify the path to the "quantize" script.'
)
parser.add_argument(
'-c', '--cores', dest='cores',
help='Specify the number of parallel quantization tasks',
default=min(4, os.cpu_count())
'-t',
'--threads',
dest='threads',
type=int,
help='Specify the number of parallel quantization tasks [default=%(default)s]',
default=min(4, os.cpu_count()),
)
args = parser.parse_args()
@ -65,39 +90,39 @@ def main():
)
sys.exit(1)
pool = multiprocessing.Pool(args.cores)
for model in args.models:
# The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
f16_model_path_base = os.path.join(
args.models_path, model, "ggml-model-f16.bin"
)
with create_executor(args.threads) as executor:
for model in args.models:
# The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
f16_model_path_base = os.path.join(
args.models_path, model, "ggml-model-f16.bin"
)
if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base)
sys.exit(1)
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
)
for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
print(
f"The f16 model {os.path.basename(f16_model_part_path)} "
f"was not found in {args.models_path}{os.path.sep}{model}"
". If you want to use it from another location, set the "
"--models-path argument from the command line."
)
if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base)
sys.exit(1)
pool.apply_async(
__run_quantize_script,
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*"),
)
pool.close()
pool.join()
for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
print(
f"The f16 model {os.path.basename(f16_model_part_path)} "
f"was not found in {args.models_path}{os.path.sep}{model}"
". If you want to use it from another location, set the "
"--models-path argument from the command line."
)
sys.exit(1)
executor(
__run_quantize_script,
args.quantize_script_path,
f16_model_part_path,
args.remove_f16,
)
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):