Move BERT-specific changes from the base Model class to BertModel
This commit is contained in:
parent
a0c5a0e82f
commit
7a5d932eaf
1 changed file with 7 additions and 7 deletions
|
@@ -177,7 +177,7 @@ class Model:
|
||||||
return False
|
return False
|
||||||
return name == (key_name + suffix)
|
return name == (key_name + suffix)
|
||||||
|
|
||||||
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias", ".beta", ".gamma")) -> str:
|
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
|
||||||
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
raise ValueError(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
|
@@ -249,9 +249,6 @@ class Model:
|
||||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if name.startswith("bert."):
|
|
||||||
name = name.removeprefix("bert.")
|
|
||||||
|
|
||||||
old_dtype = data_torch.dtype
|
old_dtype = data_torch.dtype
|
||||||
|
|
||||||
# convert any unsupported data types to float32
|
# convert any unsupported data types to float32
|
||||||
|
@@ -2194,12 +2191,15 @@ class BertModel(Model):
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
del bid # unused
|
del bid # unused
|
||||||
|
|
||||||
|
name = name.removeprefix("bert.")
|
||||||
|
|
||||||
# we are only using BERT for embeddings so we don't need the pooling layer
|
# we are only using BERT for embeddings so we don't need the pooling layer
|
||||||
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias") or "cls." in name:
|
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias") or "cls." in name:
|
||||||
return [] # we don't need these
|
return [] # we don't need these
|
||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
try_suffixes = (".weight", ".bias", ".beta", ".gamma")
|
||||||
|
return [(self.map_tensor_name(name, try_suffixes), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
@Model.register("NomicBertModel")
|
@Model.register("NomicBertModel")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue