Fixes and improvements based on Matt's observations

Fixed and improved several things in the script based on @mattsta's review. The parallelization suggestion still needs to be revised, but the code for it was added anyway, commented out (a sketch of one possible approach follows the diff below).
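With the new options in place, a typical invocation could look like this (illustrative only; it assumes the script is saved as quantize.py at the repository root, next to the compiled "quantize" binary):

    python3 quantize.py 7B 13B -r -m ./models -q ./quantize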
SuajCarrot 2023-03-18 21:36:40 -06:00
parent f8db3d6cd9
commit 2ab33114de


@@ -1,4 +1,4 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
 """Script to execute quantization on a given model."""
@@ -13,30 +13,76 @@ def main():
     parser = argparse.ArgumentParser(
         prog='Quantization Script',
-        description='This script quantizes a model.'
+        description='This script quantizes one or more models.'
     )
-    parser.add_argument("models", nargs='+', dest='models')
+    parser.add_argument(
+        "models", nargs='+',
+        choices=('7B', '13B', '30B', '65B'),
+        help='The models to quantize.'
+    )
     parser.add_argument(
         '-r', '--remove-16', action='store_true', dest='remove_f16',
         help='Remove the f16 model after quantizing it.'
     )
+    parser.add_argument(
+        '-m', '--models-path', dest='models_path',
+        default=os.path.join(os.getcwd(), "models"),
+        help='Specify the directory where the models are located.'
+    )
+    parser.add_argument(
+        '-q', '--quantize-script-path', dest='quantize_script_path',
+        default=os.path.join(os.getcwd(), "quantize"),
+        help='Specify the path to the "quantize" script.'
+    )
+    # TODO: Revise this code
+    # parser.add_argument(
+    #     '-t', '--threads', dest='threads', type=int,
+    #     default=os.cpu_count(),
+    #     help='Specify the number of threads to use to quantize many '
+    #          'models at once. Defaults to os.cpu_count().'
+    # )
     args = parser.parse_args()
+    if not os.path.isfile(args.quantize_script_path):
+        print(
+            'The "quantize" script was not found in the current location.\n'
+            "If you want to use it from another location, set the "
+            "--quantize-script-path argument from the command line."
+        )
+        sys.exit(1)
     for model in args.models:
-        model_path = os.path.join("models", model, "ggml-model-f16.bin")
-        for i in os.listdir(model_path):
-            subprocess.run(
-                ["./quantize", i, i.replace("f16", "q4_0"), "2"],
-                shell=True,
-                check=True
-            )
-            if args.remove_f16:
-                os.remove(i)
+        f16_model_path = os.path.join(
+            args.models_path, model, "ggml-model-f16.bin"
+        )
+        if not os.path.isfile(f16_model_path):
+            print(
+                "The f16 model (ggml-model-f16.bin) was not found in "
+                f"{os.path.join(args.models_path, model)}. If you want to "
+                "use it from another location, set the --models-path "
+                "argument from the command line."
+            )
+            sys.exit(1)
+        __run_quantize_script(args.quantize_script_path, f16_model_path)
+        if args.remove_f16:
+            os.remove(f16_model_path)
+
+
+# This was extracted to a top-level function for parallelization, if
+# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
+def __run_quantize_script(script_path, f16_model_path):
+    """Run the quantize script, given the path to it and the path to the
+    f16 model to quantize.
+    """
+    new_quantized_model_path = f16_model_path.replace("f16", "q4_0")
+    subprocess.run(
+        [script_path, f16_model_path, new_quantized_model_path, "2"],
+        check=True
+    )
+
+
 if __name__ == "__main__":
@@ -47,13 +93,5 @@ if __name__ == "__main__":
         print("An error occurred while trying to quantize the models.")
         sys.exit(1)
-    except FileNotFoundError as err:
-        print(
-            f'A FileNotFoundError exception was raised while executing the \
-script:\n{err}\nMake sure you are located in the root of the \
-repository and that the models are in the "models" directory.'
-        )
-        sys.exit(1)
     except KeyboardInterrupt:
         sys.exit(0)
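
For reference, the commented-out --threads option could eventually drive something like the sketch below. This is only one possible approach and is not part of the commit: it assumes the __run_quantize_script helper and the parsed args from the diff above, and it uses concurrent.futures.ThreadPoolExecutor because each job just waits on an external "quantize" process, so Python's GIL is not a bottleneck.

import os
from concurrent.futures import ThreadPoolExecutor


def quantize_models_in_parallel(args):
    """Hypothetical helper: quantize every requested model using up to
    args.threads workers. Assumes __run_quantize_script from above.
    """
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = [
            executor.submit(
                __run_quantize_script,
                args.quantize_script_path,
                os.path.join(args.models_path, model, "ggml-model-f16.bin")
            )
            for model in args.models
        ]
        # .result() re-raises any CalledProcessError from a worker, so
        # failures surface in the main thread just as in the serial loop.
        for future in futures:
            future.result()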