Update convert.py
parent f37a7d7028
commit 9f4dc236a9
1 changed file with 36 additions and 12 deletions
convert.py
@@ -310,8 +310,16 @@ class VocabLoader:
         self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer))
         vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.vocab.items()}

-        self.added_tokens_list = []
+        self.added_tokens_list = [tok for tok in self.tokenizer.get_added_vocab()]
+        self.added_tokens_dict = dict(self.tokenizer.get_added_vocab())
+        self.added_tokens_ids = set(self.tokenizer.get_added_vocab().values())
+
         self.unk_token_id = self.tokenizer.unk_token_id
+        self.specials = {
+            tok: self.tokenizer.vocab[tok]
+            for tok in self.tokenizer.all_special_tokens
+        }
+        print(self.specials)
         self.special_ids = set(self.tokenizer.all_special_ids)
         self.vocab_size_base: int = len(vocab_set)
         self.vocab_size: int = len(vocab_set)
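For orientation, here is a minimal standalone sketch (mine, not part of the commit) of what the three new added-token fields and `specials` hold for a typical Hugging Face tokenizer; the directory path is a placeholder.

# Illustrative only: assumes `transformers` is installed and that
# "./tokenizer_dir" stands in for a real tokenizer directory.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./tokenizer_dir")

added_vocab = tokenizer.get_added_vocab()          # {token_text: token_id}
added_tokens_list = [tok for tok in added_vocab]   # token texts, as in the diff
added_tokens_dict = dict(added_vocab)              # text -> id
added_tokens_ids = set(added_vocab.values())       # ids, used later to skip these in hf_tokens()

specials = {tok: tokenizer.vocab[tok] for tok in tokenizer.all_special_tokens}
print(added_tokens_ids, specials)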
@@ -321,6 +329,7 @@ class VocabLoader:
         path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
         if path_candidate is not None:
             self.spm = SentencePieceProcessor(str(path_candidate))
+            print(self.spm.vocab_size(), self.vocab_size_base)
         else:
             self.spm = None

@@ -330,18 +339,16 @@ class VocabLoader:
         special_ids = set(tokenizer.all_special_ids)

         for i in range(self.vocab_size_base):
+            if i in self.added_tokens_ids:
+                continue
+
             text = reverse_vocab[i].encode("utf-8")
             yield text, self.get_token_score(i), self.get_token_type(i)

-    def get_token_type(self, token_id):
-        toktype = gguf.TokenType.NORMAL
+    def get_token_type(self, token_id, default_type=gguf.TokenType.NORMAL):
+        toktype = default_type

-        if self.spm is None:
-            if i == self.unk_token_id:
-                toktype = gguf.TokenType.UNKNOWN
-            if i in self.special_ids:
-                toktype = gguf.TokenType.CONTROL
-        else:
+        if self.spm is not None and token_id < self.spm.vocab_size():
             if self.spm.is_unknown(token_id):
                 toktype = gguf.TokenType.UNKNOWN
             if self.spm.is_control(token_id):
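A rough sketch (mine, not from the diff) of the flow this hunk sets up: the base-vocab generator now skips ids that belong to added tokens, so those are emitted only by added_tokens().

# Sketch under the diff's assumptions: reverse_vocab maps id -> token text,
# added_tokens_ids is the set built in __init__, and get_score / get_type are
# stand-ins for self.get_token_score / self.get_token_type.
def base_tokens(reverse_vocab, vocab_size_base, added_tokens_ids, get_score, get_type):
    for i in range(vocab_size_base):
        if i in added_tokens_ids:   # added tokens are handled by added_tokens()
            continue
        yield reverse_vocab[i].encode("utf-8"), get_score(i), get_type(i)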
@@ -350,18 +357,35 @@ class VocabLoader:
             toktype = gguf.TokenType.UNUSED
         if self.spm.is_byte(token_id):
             toktype = gguf.TokenType.BYTE
+        else:
+            if token_id == self.unk_token_id:
+                toktype = gguf.TokenType.UNKNOWN
+            if token_id in self.special_ids:
+                toktype = gguf.TokenType.CONTROL
+
         return toktype

     def get_token_score(self, token_id):
-        if self.spm is not None:
+        if self.spm is not None and token_id < self.spm.vocab_size():
             return self.spm.get_score(token_id)
         else:
             return 0.0

     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        default_toktype = gguf.TokenType.USER_DEFINED
+
         for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+
+            if text in self.specials:
+
+                toktype = self.get_token_type(self.specials[text], default_toktype)
+                score = self.get_token_score(self.specials[text])
+
+            else:
+                toktype = default_toktype
+                score = -1000.0
+
+            yield text.encode("utf-8"), score, toktype

     def has_newline_token(self):
         return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
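Taken together, the reworked type resolution reads roughly as below. This is a hand-written summary, not code from convert.py; `spm`, `unk_token_id` and `special_ids` stand in for the corresponding VocabLoader attributes.

import gguf  # assumes the repo's gguf-py package is importable

def resolve_token_type(spm, token_id, unk_token_id, special_ids,
                       default_type=gguf.TokenType.NORMAL):
    # With a SentencePiece model and an in-range id, trust the spm flags;
    # otherwise fall back to the tokenizer-config view of specials.
    toktype = default_type
    if spm is not None and token_id < spm.vocab_size():
        if spm.is_unknown(token_id):
            toktype = gguf.TokenType.UNKNOWN
        if spm.is_control(token_id):
            toktype = gguf.TokenType.CONTROL
        if spm.is_unused(token_id):
            toktype = gguf.TokenType.UNUSED
        if spm.is_byte(token_id):
            toktype = gguf.TokenType.BYTE
    else:
        if token_id == unk_token_id:
            toktype = gguf.TokenType.UNKNOWN
        if token_id in special_ids:
            toktype = gguf.TokenType.CONTROL
    return toktype

# added_tokens() calls this with default_type=gguf.TokenType.USER_DEFINED for
# tokens that are also specials, and falls back to USER_DEFINED with a fixed
# score of -1000.0 for the rest.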