feat: handle gpt2 tokenizer with Jina architecture
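
Some Jina BERT checkpoints ship a RobertaTokenizer (GPT-2 style BPE)
instead of a BertTokenizer (WordPiece). Read the tokenizer class from
tokenizer_config.json and route the Roberta case through the GPT-2
vocab path; on the llama.cpp side, stop asserting that BPE vocabs never
append EOS and instead honor the model's add_eos_token flag.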

Joan Martinez 2024-04-24 10:05:34 +02:00
parent cde49b7448
commit dd060a2a4e
2 changed files with 19 additions and 2 deletions

convert-hf-to-gguf.py

@@ -134,6 +134,7 @@ class Model(ABC):
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
         for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@@ -2370,6 +2371,7 @@ class BertModel(Model):
     def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         tensors = dict(self.get_tensors())
+
         for name, data_torch in tensors.items():
             # we are only using BERT for embeddings so we don't need the pooling layer
             if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
@@ -2737,7 +2739,17 @@ class JinaBertModel(BertModel):
             yield name, data
 
     def set_vocab(self, *args, **kwargs):
-        super().set_vocab()
+        tokenizer_class = 'BertTokenizer'
+        with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
+            tokenizer_class = json.load(f)['tokenizer_class']
+
+        if tokenizer_class == 'BertTokenizer':
+            super().set_vocab()
+        elif tokenizer_class == 'RobertaTokenizer':
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_token_type_count(2)
+        else:
+            raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)
 
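
For reference, a minimal standalone sketch of the detection logic added above (the helper name detect_tokenizer_class is hypothetical, not part of the converter):

    import json
    from pathlib import Path

    def detect_tokenizer_class(dir_model: Path) -> str:
        # Mirror the converter: transformers records the tokenizer class in
        # tokenizer_config.json next to the model weights.
        with open(dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
            return json.load(f)["tokenizer_class"]

    # 'BertTokenizer'    -> WordPiece vocab, handled by BertModel.set_vocab()
    # 'RobertaTokenizer' -> GPT-2 style BPE vocab, handled by _set_vocab_gpt2()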

llama.cpp

@@ -12512,7 +12512,12 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }
 
-                GGML_ASSERT(vocab.special_add_eos != 1);
+                //GGML_ASSERT(vocab.special_add_eos != 1);
+                // TODO: check this: why does this tokenizer not add EOS at the end, and why not leave it up to the `gguf` exporter?
+                if (add_special && vocab.special_add_eos == 1) {
+                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
             {
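
The new branch reduces to a tri-state flag check. A rough Python rendering of its behavior (names are illustrative, not the llama.cpp API):

    def maybe_append_eos(output: list[int], add_special: bool,
                         special_add_eos: int, special_eos_id: int) -> None:
        # special_add_eos comes from GGUF metadata: -1 = unset, 0 = never, 1 = always.
        if add_special and special_add_eos == 1:
            assert special_eos_id != -1  # the vocab must define an EOS token
            output.append(special_eos_id)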