This commit is contained in:
wheelspawn 2024-06-13 14:45:46 -05:00
parent edb1cca353
commit f5e2558f3b

View file

@@ -177,7 +177,7 @@ class Model:
return False
return name == (key_name + suffix)
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias", ".beta", ".gamma")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
raise ValueError(f"Can not map tensor {name!r}")
@@ -248,7 +248,10 @@ class Model:
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
if name.startswith("bert."):
name = name.removeprefix("bert.")
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
@@ -2193,7 +2196,7 @@ class BertModel(Model):
del bid # unused
# we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias") or "cls." in name:
return [] # we don't need these
return [(self.map_tensor_name(name), data_torch)]