simply ignore added tokens whose id < vocab size

김승덕 / Infrastructure Group (YA) 2023-10-12 12:44:33 +09:00
parent 576df7770a
commit e6ea63ca71


@@ -366,24 +366,14 @@ class SentencePieceVocab:
             added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
         else:
             added_tokens = {}
-        items: list[tuple[str, int]] = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
-        tokens_to_replace: dict[int, str] = {}
-        new_tokens: dict[int, str] = {}
-        for piece, idx in items:
-            if idx < vocab_size:
-                tokens_to_replace[idx] = piece
-            else:
-                new_tokens[idx] = piece
+        new_tokens: dict[int, str] = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
         expected_new_ids: list[int] = list(range(vocab_size, vocab_size + len(new_tokens)))
         actual_new_ids: list[int] = sorted(new_tokens.keys())
         if expected_new_ids != actual_new_ids:
             raise Exception(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
-        # Key is the original token ID, value is the replacement token piece.
-        self.tokens_to_replace = tokens_to_replace
         # Token pieces that were added to the base vocabulary.
         self.new_tokens_list: list[str] = [new_tokens[id] for id in actual_new_ids]
         self.vocab_size_base: int = vocab_size
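
To see what the new comprehension does in isolation, here is a small standalone sketch; the vocab size and added-token entries are made up for illustration and are not part of this commit:

# Hypothetical base vocab size and added_tokens.json contents.
vocab_size = 32000
added_tokens = {"<unk>": 0, "<pad>": 32000, "<extra_0>": 32001}

# Entries whose id already falls inside the base vocabulary (here "<unk>")
# are simply ignored; only genuinely new pieces are kept, keyed by id.
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
assert new_tokens == {32000: "<pad>", 32001: "<extra_0>"}

# The sequential-id check from the hunk above then passes.
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
actual_new_ids = sorted(new_tokens.keys())
assert expected_new_ids == actual_new_ids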
@@ -394,7 +384,7 @@ class SentencePieceVocab:
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for id in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(id) if id not in self.tokens_to_replace else self.tokens_to_replace[id]
+            piece = tokenizer.id_to_piece(id)
             text: bytes = piece.encode("utf-8")
             score: float = tokenizer.get_score(id)
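
With the replacement map gone, sentencepiece_tokens now reads every piece straight from the SentencePiece model. A minimal sketch of the resulting loop, assuming a local tokenizer.model file (the path and surrounding code are illustrative, not taken from the repository):

from sentencepiece import SentencePieceProcessor

tokenizer = SentencePieceProcessor("tokenizer.model")  # illustrative model path

for id in range(tokenizer.vocab_size()):
    # No tokens_to_replace lookup any more: added tokens whose id is below
    # vocab_size no longer override the base pieces.
    piece = tokenizer.id_to_piece(id)
    text = piece.encode("utf-8")
    score = tokenizer.get_score(id)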