Fixes and improvements based on Matt's observations
Fixed and improved several things in the script based on the review comments made by @mattsta. The parallelization suggestion has not been fully worked out yet, so the code for it is included but commented out.
This commit is contained in:
parent
f8db3d6cd9
commit
2ab33114de
1 changed files with 59 additions and 21 deletions
80
quantize.py
80
quantize.py
|
@ -1,4 +1,4 @@
|
||||||
#!/usr/bin/python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
"""Script to execute quantization on a given model."""
|
"""Script to execute quantization on a given model."""
|
||||||
|
|
||||||
|
@ -13,30 +13,76 @@ def main():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog='Quantization Script',
|
prog='Quantization Script',
|
||||||
description='This script quantizes a model.'
|
description='This script quantizes a model or many models.'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"models", nargs='+', dest='models',
|
||||||
|
choices=('7B', '13B', '30B', '65B'),
|
||||||
|
help='The models to quantize.'
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("models", nargs='+', dest='models')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-r', '--remove-16', action='store_true', dest='remove_f16',
|
'-r', '--remove-16', action='store_true', dest='remove_f16',
|
||||||
help='Remove the f16 model after quantizing it.'
|
help='Remove the f16 model after quantizing it.'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-m', '--models-path', dest='models_path',
|
||||||
|
default=os.path.join(os.getcwd(), "models"),
|
||||||
|
help='Specify the directory where the models are located.'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-q', '--quantize-script-path', dest='quantize_script_path',
|
||||||
|
default=os.getcwd(),
|
||||||
|
help='Specify the path to the "quantize" script.'
|
||||||
|
)
|
||||||
|
# TODO: Revise this code
|
||||||
|
# parser.add_argument(
|
||||||
|
# '-t', '--threads', dest='threads', type='int',
|
||||||
|
# default=os.cpu_count(),
|
||||||
|
# help='Specify the number of threads to use to quantize many models at '
|
||||||
|
# 'once. Defaults to os.cpu_count().'
|
||||||
|
# )
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.isfile(args.quantize_script_path):
|
||||||
|
print(
|
||||||
|
'The "quantize" script was not found in the current location.\n'
|
||||||
|
"If you want to use it from another location, set the "
|
||||||
|
"--quantize-script-path argument from the command line."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
|
f16_model_path = os.path.join(
|
||||||
model_path = os.path.join("models", model, "ggml-model-f16.bin")
|
args.models_path, model, "ggml-model-f16.bin"
|
||||||
|
|
||||||
for i in os.listdir(model_path):
|
|
||||||
subprocess.run(
|
|
||||||
["./quantize", i, i.replace("f16", "q4_0"), "2"],
|
|
||||||
shell=True,
|
|
||||||
check=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not os.path.isfile(f16_model_path):
|
||||||
|
print(
|
||||||
|
"The f16 model (ggml-model-f16.bin) was not found in "
|
||||||
|
f"models/{model}. If you want to use it from another location,"
|
||||||
|
" set the --models-path argument from the command line."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
__run_quantize_script(args.quantize_script_path, f16_model_path)
|
||||||
|
|
||||||
if args.remove_f16:
|
if args.remove_f16:
|
||||||
os.remove(i)
|
os.remove(f16_model_path)
|
||||||
|
|
||||||
|
|
||||||
|
# This was extracted to a top-level function for parallelization, if
|
||||||
|
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
|
||||||
|
def __run_quantize_script(script_path, f16_model_path):
    """Run the quantize script on a single f16 model file.

    The quantized model is written next to the f16 model, using the same
    file name with "f16" replaced by "q4_0".

    Args:
        script_path: Path to the "quantize" executable.
        f16_model_path: Path to the f16 model file to quantize.

    Raises:
        subprocess.CalledProcessError: If the quantize script exits with a
            non-zero status (``check=True``).
    """

    # Rewrite only the file name, and match the whole "f16" marker: replacing
    # just "16" produced "ggml-model-fq4_0.bin" and could also mangle a
    # directory component that happens to contain "16".
    model_dir, f16_model_name = os.path.split(f16_model_path)
    new_quantized_model_path = os.path.join(
        model_dir, f16_model_name.replace("f16", "q4_0")
    )

    # Pass the argument list directly, without shell=True: on POSIX a list
    # combined with shell=True runs only the first item as the command and
    # the remaining items never reach the quantize script.
    # NOTE(review): "2" presumably selects the q4_0 format expected by the
    # quantize tool, matching the output file name -- confirm against the tool.
    subprocess.run(
        [script_path, f16_model_path, new_quantized_model_path, "2"],
        check=True
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -47,13 +93,5 @@ if __name__ == "__main__":
|
||||||
print("An error ocurred while trying to quantize the models.")
|
print("An error ocurred while trying to quantize the models.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
except FileNotFoundError as err:
|
|
||||||
print(
|
|
||||||
f'A FileNotFoundError exception was raised while executing the \
|
|
||||||
script:\n{err}\nMake sure you are located in the root of the \
|
|
||||||
repository and that the models are in the "models" directory.'
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue