fixes to position embeddings
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
This commit is contained in:
parent
a2e03b826f
commit
d5f69e8a43
1 changed file with 23 additions and 0 deletions
|
@ -2624,6 +2624,16 @@ class BertModel(Model):
|
||||||
@Model.register("RobertaModel")
|
@Model.register("RobertaModel")
|
||||||
class RobertaModel(BertModel):
|
class RobertaModel(BertModel):
|
||||||
model_arch = gguf.MODEL_ARCH.BERT
|
model_arch = gguf.MODEL_ARCH.BERT
|
||||||
|
def __init__(self, *args, **kwargs):
    """Initialize the model and pre-compute the position-embedding offset.

    RoBERTa position ids start at ``pad_token_id + 1`` rather than 0, so the
    offset is remembered here and ``max_position_embeddings`` is shrunk
    accordingly; ``modify_tensors`` uses the offset to trim the
    position_embd matrix.
    """
    super().__init__(*args, **kwargs)

    # We need the pad_token_id to know how to chop down the position_embd
    # matrix. When it is absent, record None so downstream code skips the trim.
    pad_token_id = self.hparams.get("pad_token_id")
    if pad_token_id is None:
        self._position_offset = None
    else:
        self._position_offset = pad_token_id + 1
        if "max_position_embeddings" in self.hparams:
            self.hparams["max_position_embeddings"] -= self._position_offset
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
"""Support BPE tokenizers for roberta models"""
|
"""Support BPE tokenizers for roberta models"""
|
||||||
|
@ -2641,6 +2651,19 @@ class RobertaModel(BertModel):
|
||||||
else:
|
else:
|
||||||
return super().set_vocab()
|
return super().set_vocab()
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Normalize RoBERTa tensor names and trim the position-embedding matrix.

    Strips an optional leading ``"roberta."`` from *name* and, for the
    position-embedding weight, drops the first ``self._position_offset``
    rows before delegating to the BERT implementation.
    """
    # Some checkpoints prefix every tensor with "roberta."
    # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
    prefix = "roberta."
    if name.startswith(prefix):
        name = name[len(prefix):]

    # Position embeddings start at pad_token_id + 1, so just chop down the
    # weight tensor by that offset (when one was recorded in __init__).
    if name == "embeddings.position_embeddings.weight" and self._position_offset is not None:
        data_torch = data_torch[self._position_offset:, :]

    return super().modify_tensors(data_torch, name, bid)
||||||
|
|
||||||
|
|
||||||
@Model.register("NomicBertModel")
|
@Model.register("NomicBertModel")
|
||||||
class NomicBertModel(BertModel):
|
class NomicBertModel(BertModel):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue