Update convert.py
This commit is contained in:
parent
13f07013ee
commit
f37a7d7028
1 changed files with 35 additions and 1 deletions
36
convert.py
36
convert.py
|
@ -311,9 +311,18 @@ class VocabLoader:
|
|||
vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.vocab.items()}
|
||||
|
||||
self.added_tokens_list = []
|
||||
self.unk_token_id = self.tokenizer.unk_token_id
|
||||
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||
self.vocab_size_base: int = len(vocab_set)
|
||||
self.vocab_size: int = len(vocab_set)
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
|
||||
vocab_file = "tokenizer.model"
|
||||
path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
|
||||
if path_candidate is not None:
|
||||
self.spm = SentencePieceProcessor(str(path_candidate))
|
||||
else:
|
||||
self.spm
|
||||
|
||||
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
tokenizer = self.tokenizer
|
||||
|
@ -322,7 +331,32 @@ class VocabLoader:
|
|||
|
||||
for i in range(self.vocab_size_base):
|
||||
text = reverse_vocab[i].encode("utf-8")
|
||||
yield text, 0.0, gguf.TokenType.NORMAL if i not in special_ids else gguf.TokenType.CONTROL
|
||||
yield text, self.get_token_score(i), self.get_token_type(i)
|
||||
|
||||
def get_token_type(self, token_id):
|
||||
toktype = gguf.TokenType.NORMAL
|
||||
|
||||
if self.spm is None:
|
||||
if i == self.unk_token_id:
|
||||
toktype = gguf.TokenType.UNKNOWN
|
||||
if i in self.special_ids:
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
else:
|
||||
if self.spm.is_unknown(token_id):
|
||||
toktype = gguf.TokenType.UNKNOWN
|
||||
if self.spm.is_control(token_id):
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
if self.spm.is_unused(token_id):
|
||||
toktype = gguf.TokenType.UNUSED
|
||||
if self.spm.is_byte(token_id):
|
||||
toktype = gguf.TokenType.BYTE
|
||||
return toktype
|
||||
|
||||
def get_token_score(self, token_id):
|
||||
if self.spm is not None:
|
||||
return self.spm.get_score(token_id)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue