From f37a7d7028d0df1f5634984c95cd634402326244 Mon Sep 17 00:00:00 2001
From: wonjun Jang
Date: Sun, 12 Nov 2023 02:22:37 +0900
Subject: [PATCH] Update convert.py

---
 convert.py | 36 +++++++++++++++++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/convert.py b/convert.py
index c4eb28391..f72ed9ab5 100755
--- a/convert.py
+++ b/convert.py
@@ -311,9 +311,18 @@ class VocabLoader:
         vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.vocab.items()}
 
         self.added_tokens_list = []
+        self.unk_token_id = self.tokenizer.unk_token_id
+        self.special_ids = set(self.tokenizer.all_special_ids)
         self.vocab_size_base: int = len(vocab_set)
         self.vocab_size: int = len(vocab_set)
         self.fname_tokenizer = fname_tokenizer
+
+        vocab_file = "tokenizer.model"
+        path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            self.spm = SentencePieceProcessor(str(path_candidate))
+        else:
+            self.spm = None
 
     def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.tokenizer
@@ -322,7 +331,32 @@ class VocabLoader:
 
         for i in range(self.vocab_size_base):
             text = reverse_vocab[i].encode("utf-8")
-            yield text, 0.0, gguf.TokenType.NORMAL if i not in special_ids else gguf.TokenType.CONTROL
+            yield text, self.get_token_score(i), self.get_token_type(i)
+
+    def get_token_type(self, token_id):
+        toktype = gguf.TokenType.NORMAL
+
+        if self.spm is None:
+            if token_id == self.unk_token_id:
+                toktype = gguf.TokenType.UNKNOWN
+            if token_id in self.special_ids:
+                toktype = gguf.TokenType.CONTROL
+        else:
+            if self.spm.is_unknown(token_id):
+                toktype = gguf.TokenType.UNKNOWN
+            if self.spm.is_control(token_id):
+                toktype = gguf.TokenType.CONTROL
+            if self.spm.is_unused(token_id):
+                toktype = gguf.TokenType.UNUSED
+            if self.spm.is_byte(token_id):
+                toktype = gguf.TokenType.BYTE
+        return toktype
+
+    def get_token_score(self, token_id):
+        if self.spm is not None:
+            return self.spm.get_score(token_id)
+        else:
+            return 0.0
 
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         for text in self.added_tokens_list:
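
Note (reviewer sketch, not part of the patch): the change above makes token
scores and types come from tokenizer.model via SentencePiece when that file is
found next to the tokenizer, and fall back to HF-only metadata (unk_token_id,
all_special_ids, score 0.0) when it is not. Below is a minimal, self-contained
sketch of that HF-only fallback path; the TokenType enum is a stand-in for
gguf.TokenType and the ids are illustrative, not taken from a real model. It
also shows that the two sequential ifs give CONTROL precedence over UNKNOWN
when the unk token is itself listed among the special ids, which is the
common case.

    # Sketch only: mirrors get_token_type()'s HF fallback branch from the patch.
    from enum import IntEnum, auto

    class TokenType(IntEnum):  # stand-in for gguf.TokenType
        NORMAL = auto()
        UNKNOWN = auto()
        CONTROL = auto()
        UNUSED = auto()
        BYTE = auto()

    def token_type_hf_only(token_id: int, unk_token_id: int, special_ids: set[int]) -> TokenType:
        toktype = TokenType.NORMAL
        if token_id == unk_token_id:
            toktype = TokenType.UNKNOWN
        if token_id in special_ids:  # overwrites UNKNOWN when unk is also special
            toktype = TokenType.CONTROL
        return toktype

    # Illustrative LLaMA-style ids: unk=0, bos=1, eos=2.
    assert token_type_hf_only(0, 0, {0, 1, 2}) is TokenType.CONTROL  # unk is also special
    assert token_type_hf_only(1, 0, {0, 1, 2}) is TokenType.CONTROL
    assert token_type_hf_only(42, 0, {0, 1, 2}) is TokenType.NORMAL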