model: dbrx: fix tensor names mapping broken

This commit is contained in:
Pierrick HYMBERT 2024-04-07 18:52:28 +02:00
parent f062b834ed
commit dbfd59114f

View file

@ -1538,7 +1538,7 @@ class DbrxModel(Model):
# Every other model has the weight names ending in .weight, # Every other model has the weight names ending in .weight,
# let's assume that is the convention which is not the case for dbrx: # let's assume that is the convention which is not the case for dbrx:
# https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15 # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=".weight") new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
if new_name is None: if new_name is None:
print(f"Can not map tensor {name!r}") print(f"Can not map tensor {name!r}")
sys.exit() sys.exit()