parallelize the quantization process
This commit is contained in:
parent
5a5f8b1501
commit
08121f3aa8
1 changed file with 17 additions and 19 deletions
36
quantize.py
Normal file → Executable file
36
quantize.py
Normal file → Executable file
|
@ -2,11 +2,12 @@
|
|||
|
||||
"""Script to execute the "quantize" script on a given set of models."""
|
||||
|
||||
import subprocess
|
||||
import argparse
|
||||
import glob
|
||||
import sys
|
||||
import multiprocessing
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -47,14 +48,11 @@ def main():
|
|||
default=os.path.join(os.getcwd(), quantize_script_binary),
|
||||
help='Specify the path to the "quantize" script.'
|
||||
)
|
||||
|
||||
# TODO: Revise this code
|
||||
# parser.add_argument(
|
||||
# '-t', '--threads', dest='threads', type='int',
|
||||
# default=os.cpu_count(),
|
||||
# help='Specify the number of threads to use to quantize many models at '
|
||||
# 'once. Defaults to os.cpu_count().'
|
||||
# )
|
||||
parser.add_argument(
|
||||
'-c', '--cores', dest='cores',
|
||||
help='Specify the number of parallel quantization tasks',
|
||||
default=min(4, os.cpu_count())
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.models_path = os.path.abspath(args.models_path)
|
||||
|
@ -67,6 +65,7 @@ def main():
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
pool = multiprocessing.Pool(args.cores)
|
||||
for model in args.models:
|
||||
# The model is separated in various parts
|
||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||
|
@ -93,18 +92,15 @@ def main():
|
|||
)
|
||||
sys.exit(1)
|
||||
|
||||
__run_quantize_script(
|
||||
args.quantize_script_path, f16_model_part_path
|
||||
pool.apply_async(
|
||||
__run_quantize_script,
|
||||
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
|
||||
)
|
||||
|
||||
if args.remove_f16:
|
||||
os.remove(f16_model_part_path)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
# This was extracted to a top-level function for parallelization, if
|
||||
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
|
||||
|
||||
def __run_quantize_script(script_path, f16_model_part_path):
|
||||
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
|
||||
"""Run the quantize script specifying the path to it and the path to the
|
||||
f16 model to quantize.
|
||||
"""
|
||||
|
@ -114,6 +110,8 @@ def __run_quantize_script(script_path, f16_model_part_path):
|
|||
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
|
||||
check=True
|
||||
)
|
||||
if remove_f16:
|
||||
os.remove(f16_model_part_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue