This commit is contained in:
Nathaniel Le Sage 2024-07-03 16:19:51 -07:00 committed by GitHub
commit 05ef32bd4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2261,12 +2261,18 @@ class BertModel(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
name = name.removeprefix("bert.")
# we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias") or "cls." in name:
return [] # we don't need these
return [(self.map_tensor_name(name), data_torch)]
name = name.replace("gamma", "weight")
name = name.replace("beta", "bias")
try_suffixes = (".weight", ".bias", ".beta", ".gamma")
return [(self.map_tensor_name(name, try_suffixes), data_torch)]
@Model.register("NomicBertModel")