apply PR suggestions
- rename -c/--cores option to -t/--threads - only use multiprocessing module when --threads > 1
This commit is contained in:
parent
08121f3aa8
commit
1041ddb2cd
1 changed files with 57 additions and 32 deletions
43
quantize.py
43
quantize.py
|
@ -3,6 +3,7 @@
|
||||||
"""Script to execute the "quantize" script on a given set of models."""
|
"""Script to execute the "quantize" script on a given set of models."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import glob
|
import glob
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
@ -10,6 +11,27 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def create_executor(threads=1):
|
||||||
|
if threads > 1:
|
||||||
|
pool = multiprocessing.Pool(threads)
|
||||||
|
|
||||||
|
def executor(func, *args):
|
||||||
|
pool.apply_async(func, args)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def executor(func, *args):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield executor
|
||||||
|
finally:
|
||||||
|
if threads > 1:
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Update the quantize binary name depending on the platform and parse
|
"""Update the quantize binary name depending on the platform and parse
|
||||||
the command line arguments and execute the script.
|
the command line arguments and execute the script.
|
||||||
|
@ -49,9 +71,12 @@ def main():
|
||||||
help='Specify the path to the "quantize" script.'
|
help='Specify the path to the "quantize" script.'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--cores', dest='cores',
|
'-t',
|
||||||
help='Specify the number of parallel quantization tasks',
|
'--threads',
|
||||||
default=min(4, os.cpu_count())
|
dest='threads',
|
||||||
|
type=int,
|
||||||
|
help='Specify the number of parallel quantization tasks [default=%(default)s]',
|
||||||
|
default=min(4, os.cpu_count()),
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -65,7 +90,7 @@ def main():
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
pool = multiprocessing.Pool(args.cores)
|
with create_executor(args.threads) as executor:
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
# The model is separated in various parts
|
# The model is separated in various parts
|
||||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||||
|
@ -79,7 +104,7 @@ def main():
|
||||||
|
|
||||||
f16_model_parts_paths = map(
|
f16_model_parts_paths = map(
|
||||||
lambda filename: os.path.join(f16_model_path_base, filename),
|
lambda filename: os.path.join(f16_model_path_base, filename),
|
||||||
glob.glob(f"{f16_model_path_base}*")
|
glob.glob(f"{f16_model_path_base}*"),
|
||||||
)
|
)
|
||||||
|
|
||||||
for f16_model_part_path in f16_model_parts_paths:
|
for f16_model_part_path in f16_model_parts_paths:
|
||||||
|
@ -92,12 +117,12 @@ def main():
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
pool.apply_async(
|
executor(
|
||||||
__run_quantize_script,
|
__run_quantize_script,
|
||||||
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
|
args.quantize_script_path,
|
||||||
|
f16_model_part_path,
|
||||||
|
args.remove_f16,
|
||||||
)
|
)
|
||||||
pool.close()
|
|
||||||
pool.join()
|
|
||||||
|
|
||||||
|
|
||||||
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue