fix gemma2 tokenizer convert
This commit is contained in:
parent
cb5fad4c6c
commit
62dc4a92cd
1 changed files with 18 additions and 7 deletions
|
@ -576,7 +576,7 @@ class Model:
|
|||
special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_sentencepiece(self):
|
||||
def _set_vocab_sentencepiece(self, add_to_gguf=True):
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
tokenizer_path = self.dir_model / 'tokenizer.model'
|
||||
|
@ -640,12 +640,15 @@ class Model:
|
|||
|
||||
self.gguf_writer.add_tokenizer_model("llama")
|
||||
self.gguf_writer.add_tokenizer_pre("default")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
if add_to_gguf:
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
return tokens, scores, toktypes, special_vocab
|
||||
|
||||
def _set_vocab_llama_hf(self):
|
||||
vocab = gguf.LlamaHfVocab(self.dir_model)
|
||||
|
@ -2345,7 +2348,15 @@ class Gemma2Model(Model):
|
|||
model_arch = gguf.MODEL_ARCH.GEMMA2
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
tokens, scores, toktypes, special_vocab = self._set_vocab_sentencepiece(add_to_gguf=False)
|
||||
# hack: This is required so that we can properly use start/end-of-turn for chat template
|
||||
for i in range(216): # 216 -> last special token
|
||||
scores[i] = -1000.0
|
||||
toktypes[i] = SentencePieceTokenTypes.CONTROL
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue