improve code, fix new line issue

This commit is contained in:
ngxson 2024-07-02 00:49:06 +02:00
parent 1a99b5ec6e
commit 922a5e2939

View file

@ -577,6 +577,18 @@ class Model:
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_sentencepiece(self, add_to_gguf=True):
tokens, scores, toktypes = self._create_vocab_sentencepiece()
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def _create_vocab_sentencepiece(self):
from sentencepiece import SentencePieceProcessor
tokenizer_path = self.dir_model / 'tokenizer.model'
@ -638,17 +650,7 @@ class Model:
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
if add_to_gguf:
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab.add_to_gguf(self.gguf_writer)
return tokens, scores, toktypes, special_vocab
return tokens, scores, toktypes
def _set_vocab_llama_hf(self):
vocab = gguf.LlamaHfVocab(self.dir_model)
@ -2348,13 +2350,18 @@ class Gemma2Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA2
def set_vocab(self):
tokens, scores, toktypes, special_vocab = self._set_vocab_sentencepiece(add_to_gguf=False)
tokens, scores, toktypes = self._create_vocab_sentencepiece()
# hack: This is required so that we can properly use start/end-of-turn for chat template
for i in range(216): # 216 -> last special token
for i in range(108):
# including <unusedX>, <start_of_turn>, <end_of_turn>
toktypes[i] = SentencePieceTokenTypes.CONTROL
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_space_prefix(False)