convert : add --skip-unknown CLI arg

This commit is contained in:
Georgi Gerganov 2024-02-13 20:43:45 +02:00
parent 997dd1fdf7
commit 9d166b0850
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1173,7 +1173,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
for (name, tensor) in model.items()} for (name, tensor) in model.items()}
def convert_model_names(model: LazyModel, params: Params) -> LazyModel: def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
tmap = gguf.TensorNameMap(ARCH, params.n_layer) tmap = gguf.TensorNameMap(ARCH, params.n_layer)
should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
@ -1199,9 +1199,11 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
for name, lazy_tensor in model.items(): for name, lazy_tensor in model.items():
tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
if name_new is None: if name_new is None:
#raise Exception(f"Unexpected tensor name: {name}") if skip_unknown:
print(f"Unexpected tensor name: {name} - skipping") print(f"Unexpected tensor name: {name} - skipping")
continue continue
else:
raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
if tensor_type in should_skip: if tensor_type in should_skip:
print(f"skipping tensor {name_new}") print(f"skipping tensor {name_new}")
@ -1392,6 +1394,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY)
parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
if args.awq_path: if args.awq_path:
@ -1463,7 +1466,7 @@ def main(args_in: list[str] | None = None) -> None:
print(f"Special vocab info: {special_vocab}") print(f"Special vocab info: {special_vocab}")
model = model_plus.model model = model_plus.model
model = convert_model_names(model, params) model = convert_model_names(model, params, args.skip_unknown)
ftype = pick_output_type(model, args.outtype) ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype) model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype) outfile = args.outfile or default_outfile(model_plus.paths, ftype)