This commit is contained in:
김승덕/Infrastructure그룹(YA) 2023-10-12 03:03:37 +09:00
parent 9b5907ead7
commit 576df7770a

View file

@ -366,7 +366,7 @@ class SentencePieceVocab:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else: else:
added_tokens = {} added_tokens = {}
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) items: list[tuple[str, int]] = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
tokens_to_replace: dict[int, str] = {} tokens_to_replace: dict[int, str] = {}
new_tokens: dict[int, str] = {} new_tokens: dict[int, str] = {}
@ -376,14 +376,16 @@ class SentencePieceVocab:
else: else:
new_tokens[idx] = piece new_tokens[idx] = piece
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) expected_new_ids: list[int] = list(range(vocab_size, vocab_size + len(new_tokens)))
actual_new_ids = sorted(new_tokens.keys()) actual_new_ids: list[int] = sorted(new_tokens.keys())
if expected_new_ids != actual_new_ids: if expected_new_ids != actual_new_ids:
raise Exception(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") raise Exception(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
# Key is the original token ID, value is the replacement token piece.
self.tokens_to_replace = tokens_to_replace self.tokens_to_replace = tokens_to_replace
self.new_tokens_list = [new_tokens[id] for id in actual_new_ids] # Token pieces that were added to the base vocabulary.
self.new_tokens_list: list[str] = [new_tokens[id] for id in actual_new_ids]
self.vocab_size_base: int = vocab_size self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.new_tokens_list) self.vocab_size: int = self.vocab_size_base + len(self.new_tokens_list)
self.fname_tokenizer = fname_tokenizer self.fname_tokenizer = fname_tokenizer