parallelize the quantization process

This commit is contained in:
Tristan Carel 2023-03-28 17:46:43 +02:00
parent 5a5f8b1501
commit 08121f3aa8

36
quantize.py Normal file → Executable file
View file

@ -2,11 +2,12 @@
"""Script to execute the "quantize" script on a given set of models.""" """Script to execute the "quantize" script on a given set of models."""
import subprocess
import argparse import argparse
import glob import glob
import sys import multiprocessing
import os import os
import subprocess
import sys
def main(): def main():
@ -47,14 +48,11 @@ def main():
default=os.path.join(os.getcwd(), quantize_script_binary), default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.' help='Specify the path to the "quantize" script.'
) )
parser.add_argument(
# TODO: Revise this code '-c', '--cores', dest='cores',
# parser.add_argument( help='Specify the number of parallel quantization tasks',
# '-t', '--threads', dest='threads', type='int', default=min(4, os.cpu_count())
# default=os.cpu_count(), )
# help='Specify the number of threads to use to quantize many models at '
# 'once. Defaults to os.cpu_count().'
# )
args = parser.parse_args() args = parser.parse_args()
args.models_path = os.path.abspath(args.models_path) args.models_path = os.path.abspath(args.models_path)
@ -67,6 +65,7 @@ def main():
) )
sys.exit(1) sys.exit(1)
pool = multiprocessing.Pool(args.cores)
for model in args.models: for model in args.models:
# The model is separated in various parts # The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
@ -93,18 +92,15 @@ def main():
) )
sys.exit(1) sys.exit(1)
__run_quantize_script( pool.apply_async(
args.quantize_script_path, f16_model_part_path __run_quantize_script,
(args.quantize_script_path, f16_model_part_path, args.remove_f16),
) )
pool.close()
if args.remove_f16: pool.join()
os.remove(f16_model_part_path)
# This was extracted to a top-level function for parallelization, if def __run_quantize_script(script_path, f16_model_part_path, remove_f16):
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_part_path):
"""Run the quantize script specifying the path to it and the path to the """Run the quantize script specifying the path to it and the path to the
f16 model to quantize. f16 model to quantize.
""" """
@ -114,6 +110,8 @@ def __run_quantize_script(script_path, f16_model_part_path):
[script_path, f16_model_part_path, new_quantized_model_path, "2"], [script_path, f16_model_part_path, new_quantized_model_path, "2"],
check=True check=True
) )
if remove_f16:
os.remove(f16_model_part_path)
if __name__ == "__main__": if __name__ == "__main__":