Review fixes

This commit is contained in:
Galunid 2023-11-07 23:14:58 +01:00
parent 73780f5939
commit 88b0d9effc

View file

@@ -837,14 +837,14 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input", help="path to write to; default: based on input",
) )
parser.add_argument( parser.add_argument(
"model", type=Path, "--outtype", type=str, choices=["f32", "f16"], default="f16", nargs='?',
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=str, choices=["f32", "f16"], default="f16", nargs='?',
help="output format - use 0 for float32, 1 for float16", help="output format - use 0 for float32, 1 for float16",
) )
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
return parser.parse_args() return parser.parse_args()
@@ -852,25 +852,27 @@ def parse_args() -> argparse.Namespace:
args = parse_args() args = parse_args()
dir_model = args.model dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir(): if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file=sys.stderr) print(f'Error: {args.model} is not a directory', file=sys.stderr)
sys.exit(1) sys.exit(1)
ftype_str = ["f32", "f16"] ftype = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
}
if args.outfile is not None: if args.outfile is not None:
fname_out = args.outfile fname_out = args.outfile
else: else:
# output in the same directory as the model by default # output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype}.gguf' fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
print(f"Loading model: {dir_model.name}") print(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model) hparams = Model.load_hparams(dir_model)
model_class = Model.from_model_architecture(hparams["architectures"][0]) model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_str.index(ftype), fname_out, args.bigendian) model_instance = model_class(dir_model, ftype[args.outtype], fname_out, args.bigendian)
print("Set model parameters") print("Set model parameters")
model_instance.set_gguf_parameters() model_instance.set_gguf_parameters()