py : fix outfile and outtype

This commit is contained in:
Georgi Gerganov 2024-01-09 20:40:11 +02:00
parent 787860ada2
commit 90582b7341
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1442,7 +1442,7 @@ def default_output_file(model_paths: list[Path], file_type: GGMLFileType) -> Pat
if ret in model_paths:
sys.stderr.write(
f"Error: Default output path ({ret}) would overwrite the input. "
"Please explicitly specify a path using --out-file.\n"
"Please explicitly specify a path using --outfile.\n"
)
sys.exit(1)
return ret
@ -1500,7 +1500,7 @@ def get_argument_parser() -> ArgumentParser:
)
parser.add_argument(
"--out-type",
"--outtype",
choices=output_choices,
help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)",
)
@ -1525,7 +1525,7 @@ def get_argument_parser() -> ArgumentParser:
)
parser.add_argument(
"--out-file",
"--outfile",
type=Path,
help="Specify the path for the output file (default is based on input)",
)
@ -1599,12 +1599,12 @@ def main(argv: Optional[list[str]] = None) -> None:
)
params.n_ctx = args.ctx
if args.out_type:
if args.outtype:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0,
}[args.out_type]
}[args.outtype]
print(f"params = {params}")
@ -1614,18 +1614,18 @@ def main(argv: Optional[list[str]] = None) -> None:
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
if args.vocab_only:
if not args.out_file:
raise ValueError("need --out-file if using --vocab-only")
out_file = args.out_file
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile
OutputFile.write_vocab_only(
out_file,
outfile,
params,
vocab,
special_vocab,
endianess=endianess,
pad_vocab=args.pad_vocab,
)
print(f"Wrote {out_file}")
print(f"Wrote {outfile}")
return
if model_plus.vocab is not None and args.vocab_dir is None:
@ -1633,15 +1633,15 @@ def main(argv: Optional[list[str]] = None) -> None:
model = model_plus.model
model = convert_model_names(model, params)
ftype = pick_output_type(model, args.out_type)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
out_file = args.out_file or default_output_file(model_plus.paths, ftype)
outfile = args.outfile or default_output_file(model_plus.paths, ftype)
params.ftype = ftype
print(f"Writing {out_file}, format {ftype}")
print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(
out_file,
outfile,
ftype,
params,
model,
@ -1651,7 +1651,7 @@ def main(argv: Optional[list[str]] = None) -> None:
endianess=endianess,
pad_vocab=args.pad_vocab,
)
print(f"Wrote {out_file}")
print(f"Wrote {outfile}")
if __name__ == "__main__":