Small fixes to the previous commit

This commit is contained in:
SuajCarrot 2023-03-18 21:58:55 -06:00
parent 2ab33114de
commit 01237dd6f1

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Script to execute quantization on a given model.""" """Script to execute the "quantize" script on a given set of models."""
import subprocess import subprocess
import argparse import argparse
@ -13,11 +13,11 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='Quantization Script', prog='Quantization Script',
description='This script quantizes a model or many models.' description='This script quantizes the given models by applying the '
'"quantize" script on them.'
) )
parser.add_argument( parser.add_argument(
"models", nargs='+', dest='models', "models", nargs='+', choices=('7B', '13B', '30B', '65B'),
choices=('7B', '13B', '30B', '65B'),
help='The models to quantize.' help='The models to quantize.'
) )
parser.add_argument( parser.add_argument(
@ -31,9 +31,10 @@ def main():
) )
parser.add_argument( parser.add_argument(
'-q', '--quantize-script-path', dest='quantize_script_path', '-q', '--quantize-script-path', dest='quantize_script_path',
default=os.getcwd(), default=os.path.join(os.getcwd(), "quantize"),
help='Specify the path to the "quantize" script.' help='Specify the path to the "quantize" script.'
) )
# TODO: Revise this code # TODO: Revise this code
# parser.add_argument( # parser.add_argument(
# '-t', '--threads', dest='threads', type='int', # '-t', '--threads', dest='threads', type='int',
@ -73,6 +74,7 @@ def main():
# This was extracted to a top-level function for parallelization, if # This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406 # implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_path): def __run_quantize_script(script_path, f16_model_path):
"""Run the quantize script specifying the path to it and the path to the """Run the quantize script specifying the path to it and the path to the
f16 model to quantize. f16 model to quantize.
@ -90,8 +92,11 @@ if __name__ == "__main__":
main() main()
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print("An error ocurred while trying to quantize the models.") print("\nAn error ocurred while trying to quantize the models.")
sys.exit(1) sys.exit(1)
except KeyboardInterrupt: except KeyboardInterrupt:
sys.exit(0) sys.exit(0)
else:
print("\nSuccesfully quantized all models.")