parallelize the quantization process

This commit is contained in:
Tristan Carel 2023-03-28 17:46:43 +02:00
parent 5a5f8b1501
commit 08121f3aa8

36
quantize.py Normal file → Executable file
View file

@ -2,11 +2,12 @@
"""Script to execute the "quantize" script on a given set of models."""
import subprocess
import argparse
import glob
import sys
import multiprocessing
import os
import subprocess
import sys
def main():
@ -47,14 +48,11 @@ def main():
default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.'
)
# TODO: Revise this code
# parser.add_argument(
# '-t', '--threads', dest='threads', type='int',
# default=os.cpu_count(),
# help='Specify the number of threads to use to quantize many models at '
# 'once. Defaults to os.cpu_count().'
# )
parser.add_argument(
'-c', '--cores', dest='cores',
help='Specify the number of parallel quantization tasks',
default=min(4, os.cpu_count())
)
args = parser.parse_args()
args.models_path = os.path.abspath(args.models_path)
@ -67,6 +65,7 @@ def main():
)
sys.exit(1)
pool = multiprocessing.Pool(args.cores)
for model in args.models:
# The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
@ -93,18 +92,15 @@ def main():
)
sys.exit(1)
__run_quantize_script(
args.quantize_script_path, f16_model_part_path
pool.apply_async(
__run_quantize_script,
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
)
if args.remove_f16:
os.remove(f16_model_part_path)
pool.close()
pool.join()
# This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_part_path):
def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
"""Run the quantize script specifying the path to it and the path to the
f16 model to quantize.
"""
@ -114,6 +110,8 @@ def __run_quantize_script(script_path, f16_model_part_path):
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
check=True
)
if remove_f16:
os.remove(f16_model_part_path)
if __name__ == "__main__":