blacked quantize

This commit is contained in:
Geeks-sid 2023-03-29 19:16:48 -04:00
parent 382c0c6100
commit 172febff3f

View file

@ -25,27 +25,36 @@ def main():
quantize_script_binary = "quantize"
parser = argparse.ArgumentParser(
prog='python3 quantize.py',
description='This script quantizes the given models by applying the '
f'"{quantize_script_binary}" script on them.'
prog="python3 quantize.py",
description="This script quantizes the given models by applying the "
f'"{quantize_script_binary}" script on them.',
)
parser.add_argument(
'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
help='The models to quantize.'
"models",
nargs="+",
choices=("7B", "13B", "30B", "65B"),
help="The models to quantize.",
)
parser.add_argument(
'-r', '--remove-16', action='store_true', dest='remove_f16',
help='Remove the f16 model after quantizing it.'
"-r",
"--remove-16",
action="store_true",
dest="remove_f16",
help="Remove the f16 model after quantizing it.",
)
parser.add_argument(
'-m', '--models-path', dest='models_path',
"-m",
"--models-path",
dest="models_path",
default=os.path.join(os.getcwd(), "models"),
help='Specify the directory where the models are located.'
help="Specify the directory where the models are located.",
)
parser.add_argument(
'-q', '--quantize-script-path', dest='quantize_script_path',
"-q",
"--quantize-script-path",
dest="quantize_script_path",
default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.'
help='Specify the path to the "quantize" script.',
)
# TODO: Revise this code
@ -75,12 +84,12 @@ def main():
)
if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base)
print(f"The file %s was not found" % f16_model_path_base)
sys.exit(1)
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
glob.glob(f"{f16_model_path_base}*"),
)
for f16_model_part_path in f16_model_parts_paths:
@ -93,9 +102,7 @@ def main():
)
sys.exit(1)
__run_quantize_script(
args.quantize_script_path, f16_model_part_path
)
__run_quantize_script(args.quantize_script_path, f16_model_part_path)
if args.remove_f16:
os.remove(f16_model_part_path)
@ -104,6 +111,7 @@ def main():
# This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_part_path):
"""Run the quantize script specifying the path to it and the path to the
f16 model to quantize.
@ -111,8 +119,7 @@ def __run_quantize_script(script_path, f16_model_part_path):
new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
subprocess.run(
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
check=True
[script_path, f16_model_part_path, new_quantized_model_path, "2"], check=True
)