model: dbrx: fix tensor names mapping broken

This commit is contained in:
Pierrick HYMBERT 2024-04-07 18:52:28 +02:00
parent f062b834ed
commit dbfd59114f

View file

@ -1538,7 +1538,7 @@ class DbrxModel(Model):
# Every other model has the weight names ending in .weight,
# let's assume that is the convention which is not the case for dbrx:
# https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=".weight")
new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()