apply PR suggestions
- rename -c/--cores option to -t/--threads
- only use multiprocessing module when --threads > 1
This commit is contained in:
parent
08121f3aa8
commit
1041ddb2cd
1 changed files with 57 additions and 32 deletions
89
quantize.py
89
quantize.py
|
@ -3,6 +3,7 @@
|
||||||
"""Script to execute the "quantize" script on a given set of models."""
|
"""Script to execute the "quantize" script on a given set of models."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import glob
|
import glob
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
@ -10,6 +11,27 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def create_executor(threads=1):
    """Yield an ``executor(func, *args)`` callable that runs submitted tasks.

    When ``threads`` is greater than 1, a ``multiprocessing.Pool`` with that
    many workers is created and each call schedules ``func(*args)`` on it
    asynchronously. Otherwise no pool is created and each call simply runs
    ``func(*args)`` in the current process.

    Args:
        threads: Number of worker processes; values of 1 or less select the
            in-process (serial) executor.

    Yields:
        A callable ``executor(func, *args)``. In serial mode it returns the
        result of ``func(*args)``; in pool mode it returns the
        ``AsyncResult`` of the scheduled task.

    Raises:
        Exception: On a clean exit from the ``with`` body, the first
            exception raised by any pooled task is re-raised here, so a
            failing worker is no longer silently ignored.
    """
    pool = multiprocessing.Pool(threads) if threads > 1 else None
    pending = []

    if pool is not None:

        def executor(func, *args):
            # Keep the handle: a discarded AsyncResult hides worker errors.
            result = pool.apply_async(func, args)
            pending.append(result)
            return result

    else:

        def executor(func, *args):
            return func(*args)

    try:
        yield executor
    finally:
        if pool is not None:
            # Stop accepting work, then wait for every scheduled task.
            pool.close()
            pool.join()

    # Only reached when the ``with`` body finished without raising, so a
    # worker failure never masks an exception coming from the body itself.
    for result in pending:
        result.get()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Update the quantize binary name depending on the platform and parse
|
"""Update the quantize binary name depending on the platform and parse
|
||||||
the command line arguments and execute the script.
|
the command line arguments and execute the script.
|
||||||
|
@ -49,9 +71,12 @@ def main():
|
||||||
help='Specify the path to the "quantize" script.'
|
help='Specify the path to the "quantize" script.'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--cores', dest='cores',
|
'-t',
|
||||||
help='Specify the number of parallel quantization tasks',
|
'--threads',
|
||||||
default=min(4, os.cpu_count())
|
dest='threads',
|
||||||
|
type=int,
|
||||||
|
help='Specify the number of parallel quantization tasks [default=%(default)s]',
|
||||||
|
default=min(4, os.cpu_count()),
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -65,39 +90,39 @@ def main():
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
pool = multiprocessing.Pool(args.cores)
|
with create_executor(args.threads) as executor:
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
# The model is separated in various parts
|
# The model is separated in various parts
|
||||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||||
f16_model_path_base = os.path.join(
|
f16_model_path_base = os.path.join(
|
||||||
args.models_path, model, "ggml-model-f16.bin"
|
args.models_path, model, "ggml-model-f16.bin"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.isfile(f16_model_path_base):
|
if not os.path.isfile(f16_model_path_base):
|
||||||
print(f'The file %s was not found' % f16_model_path_base)
|
print(f'The file %s was not found' % f16_model_path_base)
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
f16_model_parts_paths = map(
|
|
||||||
lambda filename: os.path.join(f16_model_path_base, filename),
|
|
||||||
glob.glob(f"{f16_model_path_base}*")
|
|
||||||
)
|
|
||||||
|
|
||||||
for f16_model_part_path in f16_model_parts_paths:
|
|
||||||
if not os.path.isfile(f16_model_part_path):
|
|
||||||
print(
|
|
||||||
f"The f16 model {os.path.basename(f16_model_part_path)} "
|
|
||||||
f"was not found in {args.models_path}{os.path.sep}{model}"
|
|
||||||
". If you want to use it from another location, set the "
|
|
||||||
"--models-path argument from the command line."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
pool.apply_async(
|
f16_model_parts_paths = map(
|
||||||
__run_quantize_script,
|
lambda filename: os.path.join(f16_model_path_base, filename),
|
||||||
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
|
glob.glob(f"{f16_model_path_base}*"),
|
||||||
)
|
)
|
||||||
pool.close()
|
|
||||||
pool.join()
|
for f16_model_part_path in f16_model_parts_paths:
|
||||||
|
if not os.path.isfile(f16_model_part_path):
|
||||||
|
print(
|
||||||
|
f"The f16 model {os.path.basename(f16_model_part_path)} "
|
||||||
|
f"was not found in {args.models_path}{os.path.sep}{model}"
|
||||||
|
". If you want to use it from another location, set the "
|
||||||
|
"--models-path argument from the command line."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
executor(
|
||||||
|
__run_quantize_script,
|
||||||
|
args.quantize_script_path,
|
||||||
|
f16_model_part_path,
|
||||||
|
args.remove_f16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue