feat: handle gpt2 tokenizer with Jina architecture
parent cde49b7448
commit dd060a2a4e

2 changed files with 19 additions and 2 deletions
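Jina BERT checkpoints do not all ship the same tokenizer: tokenizer_config.json may name BertTokenizer (WordPiece) or RobertaTokenizer (a GPT-2-style BPE vocabulary). The converter previously always took the BERT vocab path; it now reads tokenizer_class and routes RobertaTokenizer models through the GPT-2 vocab. On the runtime side, the BPE tokenizer path used to assert that EOS is never auto-appended; it now appends the EOS token when the model metadata asks for it.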
|
@@ -134,6 +134,7 @@ class Model(ABC):
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

         for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
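A side note on the block_count lookup above: the chained dict.get calls let any of the three hyperparameter key names used by different checkpoints resolve to the layer count. A minimal sketch with hypothetical hparams dicts:

    # Hypothetical hparams dicts; each uses a different key for the layer count.
    for hparams in ({"n_layers": 12}, {"num_hidden_layers": 24}, {"n_layer": 6}):
        block_count = hparams.get("n_layers", hparams.get("num_hidden_layers", hparams.get("n_layer")))
        print(block_count)  # 12, 24, 6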
@@ -2370,6 +2371,7 @@ class BertModel(Model):
     def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         tensors = dict(self.get_tensors())

         for name, data_torch in tensors.items():
             # we are only using BERT for embeddings so we don't need the pooling layer
             if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
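BertModel.write_tensors follows the same pattern, dropping the pooler weights since only the embedding outputs are used. The skip is a plain membership test; as a standalone filter (illustrative tensor names):

    # Names to drop, as in the hunk above; values stand in for real tensors.
    SKIP = ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias")
    tensors = {"pooler.dense.weight": ..., "encoder.layer.0.attention.self.query.weight": ...}
    kept = {name: t for name, t in tensors.items() if name not in SKIP}
    print(list(kept))  # ['encoder.layer.0.attention.self.query.weight']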
@@ -2737,7 +2739,17 @@ class JinaBertModel(BertModel):
             yield name, data

     def set_vocab(self, *args, **kwargs):
-        super().set_vocab()
+        tokenizer_class = 'BertTokenizer'
+        with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
+            tokenizer_class = json.load(f)['tokenizer_class']
+
+        if tokenizer_class == 'BertTokenizer':
+            super().set_vocab()
+        elif tokenizer_class == 'RobertaTokenizer':
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_token_type_count(2)
+        else:
+            raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)

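The new detection step can be exercised on its own. A minimal sketch, assuming a local model directory with a Hugging Face tokenizer_config.json; detect_tokenizer_class is an illustrative helper, not part of the converter, and unlike the diff it tolerates a missing file or key:

    import json
    from pathlib import Path

    def detect_tokenizer_class(dir_model: Path) -> str:
        # Default mirrors the converter: fall back to BertTokenizer.
        tokenizer_class = "BertTokenizer"
        config_path = dir_model / "tokenizer_config.json"
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                tokenizer_class = json.load(f).get("tokenizer_class", tokenizer_class)
        return tokenizer_class

    print(detect_tokenizer_class(Path("models/jina-bert")))  # e.g. "RobertaTokenizer"

A RobertaTokenizer result implies a GPT-2-style BPE vocabulary, hence the _set_vocab_gpt2() route, while add_token_type_count(2) preserves the two token types (segment ids) that the BERT-style architecture expects.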
@@ -12512,7 +12512,12 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                        }
                    }

-                    GGML_ASSERT(vocab.special_add_eos != 1);
+                    //GGML_ASSERT(vocab.special_add_eos != 1);
+                    //TODO: Check this, why this tokenizer does not add at the end, why not leaving up to the `gguf` exporter?
+                    if (add_special && vocab.special_add_eos == 1) {
+                        GGML_ASSERT(vocab.special_add_eos != -1);
+                        output.push_back(vocab.special_eos_id);
+                    }
                 } break;
             case LLAMA_VOCAB_TYPE_WPM:
                 {
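Behaviorally, the C++ change reduces to a small rule; a Python mirror of it (illustrative only, with names following the C++ and assuming llama.cpp's convention that -1 means unspecified):

    def maybe_append_eos(output: list, add_special: bool,
                         special_add_eos: int, special_eos_id: int) -> None:
        # special_add_eos tracks the GGUF add_eos_token flag:
        # 1 = append EOS, 0 = do not, -1 = unspecified.
        if add_special and special_add_eos == 1:
            assert special_eos_id != -1  # the vocab must define an EOS token
            output.append(special_eos_id)

The converter's add_add_eos_token(True) above is what sets that flag at conversion time, so the decision stays with the exporter and the runtime merely obeys it.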