diff --git a/convert.py b/convert.py
index f72ed9ab5..c79318887 100755
--- a/convert.py
+++ b/convert.py
@@ -310,8 +310,16 @@ class VocabLoader:
         self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer))
         vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.vocab.items()}
 
-        self.added_tokens_list = []
+        self.added_tokens_list = [tok for tok in self.tokenizer.get_added_vocab()]
+        self.added_tokens_dict = dict(self.tokenizer.get_added_vocab())
+        self.added_tokens_ids = set(self.tokenizer.get_added_vocab().values())
+        self.unk_token_id = self.tokenizer.unk_token_id
+        self.specials = {
+            tok: self.tokenizer.vocab[tok]
+            for tok in self.tokenizer.all_special_tokens
+        }
+        print(self.specials)
         self.special_ids = set(self.tokenizer.all_special_ids)
         self.vocab_size_base: int = len(vocab_set)
         self.vocab_size: int = len(vocab_set)
 
@@ -321,6 +329,7 @@ class VocabLoader:
         path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
         if path_candidate is not None:
             self.spm = SentencePieceProcessor(str(path_candidate))
+            print(self.spm.vocab_size(), self.vocab_size_base)
         else:
             self.spm = None
 
@@ -330,18 +339,16 @@ class VocabLoader:
         special_ids = set(tokenizer.all_special_ids)
 
         for i in range(self.vocab_size_base):
+            if i in self.added_tokens_ids:
+                continue
+
             text = reverse_vocab[i].encode("utf-8")
             yield text, self.get_token_score(i), self.get_token_type(i)
 
-    def get_token_type(self, token_id):
-        toktype = gguf.TokenType.NORMAL
+    def get_token_type(self, token_id, default_type=gguf.TokenType.NORMAL):
+        toktype = default_type
 
-        if self.spm is None:
-            if i == self.unk_token_id:
-                toktype = gguf.TokenType.UNKNOWN
-            if i in self.special_ids:
-                toktype = gguf.TokenType.CONTROL
-        else:
+        if self.spm is not None and token_id < self.spm.vocab_size():
             if self.spm.is_unknown(token_id):
                 toktype = gguf.TokenType.UNKNOWN
             if self.spm.is_control(token_id):
@@ -350,18 +357,35 @@ class VocabLoader:
                 toktype = gguf.TokenType.UNUSED
             if self.spm.is_byte(token_id):
                 toktype = gguf.TokenType.BYTE
+        else:
+            if token_id == self.unk_token_id:
+                toktype = gguf.TokenType.UNKNOWN
+            if token_id in self.special_ids:
+                toktype = gguf.TokenType.CONTROL
+
         return toktype
 
     def get_token_score(self, token_id):
-        if self.spm is not None:
+        if self.spm is not None and token_id < self.spm.vocab_size():
             return self.spm.get_score(token_id)
         else:
             return 0.0
 
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        default_toktype = gguf.TokenType.USER_DEFINED
+
         for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+
+            if text in self.specials:
+
+                toktype = self.get_token_type(self.specials[text], default_toktype)
+                score = self.get_token_score(self.specials[text])
+
+            else:
+                toktype = default_toktype
+                score = -1000.0
+
+            yield text.encode("utf-8"), score, toktype
 
     def has_newline_token(self):
         return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
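
The net effect of the patch: token ids covered by the SentencePiece model keep their spm-derived type and score, while ids beyond spm.vocab_size() (typical for tokens appended through the HF tokenizer) fall back to the unk-id/special-id checks, and added tokens are USER_DEFINED with score -1000.0 unless they are specials. Below is a minimal standalone sketch of that fallback precedence, not part of the patch: the enum mirrors the gguf.TokenType values used in the diff, and classify_fallback is a hypothetical helper written only to illustrate the branch ordering.

    # Sketch only; gguf and SentencePiece are stubbed out, and
    # classify_fallback is a hypothetical helper, not code from the patch.
    from enum import IntEnum

    class TokenType(IntEnum):
        # Mirrors the gguf.TokenType values referenced in the diff.
        NORMAL = 1
        UNKNOWN = 2
        CONTROL = 3
        USER_DEFINED = 4
        UNUSED = 5
        BYTE = 6

    def classify_fallback(token_id, unk_token_id, special_ids,
                          default_type=TokenType.NORMAL):
        # The non-SentencePiece branch of get_token_type(): ids past
        # spm.vocab_size(), or any id when no tokenizer.model was found,
        # are typed from the HF tokenizer's metadata. Note the ordering:
        # a token that is both unk and special ends up CONTROL.
        toktype = default_type
        if token_id == unk_token_id:
            toktype = TokenType.UNKNOWN
        if token_id in special_ids:
            toktype = TokenType.CONTROL
        return toktype

    # An added special appended at id 32000, past a 32000-entry spm vocab,
    # now comes out CONTROL instead of a blanket USER_DEFINED:
    print(classify_fallback(32000, 0, {32000}, TokenType.USER_DEFINED))
    # A plain added token keeps the USER_DEFINED default (score -1000.0):
    print(classify_fallback(32001, 0, {32000}, TokenType.USER_DEFINED))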