py : fix outfile and outtype

This commit is contained in:
Georgi Gerganov 2024-01-09 20:40:11 +02:00
parent 787860ada2
commit 90582b7341
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1442,7 +1442,7 @@ def default_output_file(model_paths: list[Path], file_type: GGMLFileType) -> Pat
if ret in model_paths: if ret in model_paths:
sys.stderr.write( sys.stderr.write(
f"Error: Default output path ({ret}) would overwrite the input. " f"Error: Default output path ({ret}) would overwrite the input. "
"Please explicitly specify a path using --out-file.\n" "Please explicitly specify a path using --outfile.\n"
) )
sys.exit(1) sys.exit(1)
return ret return ret
@ -1500,7 +1500,7 @@ def get_argument_parser() -> ArgumentParser:
) )
parser.add_argument( parser.add_argument(
"--out-type", "--outtype",
choices=output_choices, choices=output_choices,
help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)",
) )
@ -1525,7 +1525,7 @@ def get_argument_parser() -> ArgumentParser:
) )
parser.add_argument( parser.add_argument(
"--out-file", "--outfile",
type=Path, type=Path,
help="Specify the path for the output file (default is based on input)", help="Specify the path for the output file (default is based on input)",
) )
@ -1599,12 +1599,12 @@ def main(argv: Optional[list[str]] = None) -> None:
) )
params.n_ctx = args.ctx params.n_ctx = args.ctx
if args.out_type: if args.outtype:
params.ftype = { params.ftype = {
"f32": GGMLFileType.AllF32, "f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16, "f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0, "q8_0": GGMLFileType.MostlyQ8_0,
}[args.out_type] }[args.outtype]
print(f"params = {params}") print(f"params = {params}")
@ -1614,18 +1614,18 @@ def main(argv: Optional[list[str]] = None) -> None:
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path) vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
if args.vocab_only: if args.vocab_only:
if not args.out_file: if not args.outfile:
raise ValueError("need --out-file if using --vocab-only") raise ValueError("need --outfile if using --vocab-only")
out_file = args.out_file outfile = args.outfile
OutputFile.write_vocab_only( OutputFile.write_vocab_only(
out_file, outfile,
params, params,
vocab, vocab,
special_vocab, special_vocab,
endianess=endianess, endianess=endianess,
pad_vocab=args.pad_vocab, pad_vocab=args.pad_vocab,
) )
print(f"Wrote {out_file}") print(f"Wrote {outfile}")
return return
if model_plus.vocab is not None and args.vocab_dir is None: if model_plus.vocab is not None and args.vocab_dir is None:
@ -1633,15 +1633,15 @@ def main(argv: Optional[list[str]] = None) -> None:
model = model_plus.model model = model_plus.model
model = convert_model_names(model, params) model = convert_model_names(model, params)
ftype = pick_output_type(model, args.out_type) ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype) model = convert_to_output_type(model, ftype)
out_file = args.out_file or default_output_file(model_plus.paths, ftype) outfile = args.outfile or default_output_file(model_plus.paths, ftype)
params.ftype = ftype params.ftype = ftype
print(f"Writing {out_file}, format {ftype}") print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all( OutputFile.write_all(
out_file, outfile,
ftype, ftype,
params, params,
model, model,
@ -1651,7 +1651,7 @@ def main(argv: Optional[list[str]] = None) -> None:
endianess=endianess, endianess=endianess,
pad_vocab=args.pad_vocab, pad_vocab=args.pad_vocab,
) )
print(f"Wrote {out_file}") print(f"Wrote {outfile}")
if __name__ == "__main__": if __name__ == "__main__":