apply PR suggestions
- rename -c/--cores option to -t/--threads - only use multiprocessing module when --threads > 1
This commit is contained in:
parent
08121f3aa8
commit
1041ddb2cd
1 changed files with 57 additions and 32 deletions
89
quantize.py
89
quantize.py
|
@ -3,6 +3,7 @@
|
|||
"""Script to execute the "quantize" script on a given set of models."""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import glob
|
||||
import multiprocessing
|
||||
import os
|
||||
|
@ -10,6 +11,27 @@ import subprocess
|
|||
import sys
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def create_executor(threads=1):
|
||||
if threads > 1:
|
||||
pool = multiprocessing.Pool(threads)
|
||||
|
||||
def executor(func, *args):
|
||||
pool.apply_async(func, args)
|
||||
|
||||
else:
|
||||
|
||||
def executor(func, *args):
|
||||
return func(*args)
|
||||
|
||||
try:
|
||||
yield executor
|
||||
finally:
|
||||
if threads > 1:
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
def main():
|
||||
"""Update the quantize binary name depending on the platform and parse
|
||||
the command line arguments and execute the script.
|
||||
|
@ -49,9 +71,12 @@ def main():
|
|||
help='Specify the path to the "quantize" script.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-c', '--cores', dest='cores',
|
||||
help='Specify the number of parallel quantization tasks',
|
||||
default=min(4, os.cpu_count())
|
||||
'-t',
|
||||
'--threads',
|
||||
dest='threads',
|
||||
type=int,
|
||||
help='Specify the number of parallel quantization tasks [default=%(default)s]',
|
||||
default=min(4, os.cpu_count()),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -65,39 +90,39 @@ def main():
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
pool = multiprocessing.Pool(args.cores)
|
||||
for model in args.models:
|
||||
# The model is separated in various parts
|
||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||
f16_model_path_base = os.path.join(
|
||||
args.models_path, model, "ggml-model-f16.bin"
|
||||
)
|
||||
with create_executor(args.threads) as executor:
|
||||
for model in args.models:
|
||||
# The model is separated in various parts
|
||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||
f16_model_path_base = os.path.join(
|
||||
args.models_path, model, "ggml-model-f16.bin"
|
||||
)
|
||||
|
||||
if not os.path.isfile(f16_model_path_base):
|
||||
print(f'The file %s was not found' % f16_model_path_base)
|
||||
sys.exit(1)
|
||||
|
||||
f16_model_parts_paths = map(
|
||||
lambda filename: os.path.join(f16_model_path_base, filename),
|
||||
glob.glob(f"{f16_model_path_base}*")
|
||||
)
|
||||
|
||||
for f16_model_part_path in f16_model_parts_paths:
|
||||
if not os.path.isfile(f16_model_part_path):
|
||||
print(
|
||||
f"The f16 model {os.path.basename(f16_model_part_path)} "
|
||||
f"was not found in {args.models_path}{os.path.sep}{model}"
|
||||
". If you want to use it from another location, set the "
|
||||
"--models-path argument from the command line."
|
||||
)
|
||||
if not os.path.isfile(f16_model_path_base):
|
||||
print(f'The file %s was not found' % f16_model_path_base)
|
||||
sys.exit(1)
|
||||
|
||||
pool.apply_async(
|
||||
__run_quantize_script,
|
||||
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
|
||||
f16_model_parts_paths = map(
|
||||
lambda filename: os.path.join(f16_model_path_base, filename),
|
||||
glob.glob(f"{f16_model_path_base}*"),
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
for f16_model_part_path in f16_model_parts_paths:
|
||||
if not os.path.isfile(f16_model_part_path):
|
||||
print(
|
||||
f"The f16 model {os.path.basename(f16_model_part_path)} "
|
||||
f"was not found in {args.models_path}{os.path.sep}{model}"
|
||||
". If you want to use it from another location, set the "
|
||||
"--models-path argument from the command line."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
executor(
|
||||
__run_quantize_script,
|
||||
args.quantize_script_path,
|
||||
f16_model_part_path,
|
||||
args.remove_f16,
|
||||
)
|
||||
|
||||
|
||||
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue