diff --git a/convert.py b/convert.py
index 8bc06120d..7a35a99a2 100644
--- a/convert.py
+++ b/convert.py
@@ -243,7 +243,7 @@ class Params:
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], fname_tokenizer_config: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
             self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -268,13 +268,40 @@ class SentencePieceVocab:
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
-        special_tokens: Dict[str, Dict[str, Any]]
+        self.special_tokens_map: Dict[int, str] = {}
+
+        TOKEN_NAME_TO_ID: Dict[str, int] = {
+            "unk_token": self.sentencepiece_tokenizer.unk_id(),
+            "bos_token": self.sentencepiece_tokenizer.bos_id(),
+            "eos_token": self.sentencepiece_tokenizer.eos_id(),
+            "pad_token": self.sentencepiece_tokenizer.pad_id()
+        }
+
+        tokenizer_config: Dict[str, Any]
+        if fname_tokenizer_config is not None:
+            tokenizer_config = json.load(open(fname_tokenizer_config))
+        else:
+            tokenizer_config = {}
+        for key, value in tokenizer_config.items():
+            assert isinstance(value, dict) or isinstance(value, str)
+            if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1:
+                continue
+            self.special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value
+
+        special_tokens: Dict[str, Any]
         if fname_special_tokens is not None:
             special_tokens = json.load(open(fname_special_tokens))
         else:
             special_tokens = {}
-        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
-        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        for key, value in special_tokens.items():
+            assert isinstance(value, dict) or isinstance(value, str)
+            if key not in TOKEN_NAME_TO_ID:
+                continue
+            token_id = TOKEN_NAME_TO_ID[key]
+            if token_id == -1 or token_id in self.special_tokens_map:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
+
         self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
@@ -1282,7 +1309,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
     special_tokens_path = path.parent / "special_tokens_map.json"
     tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None,
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else None, tokenizer_config_path if tokenizer_config_path.exists() else None,
                               vocabtype)
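
For reference, here is a minimal standalone sketch of the precedence the updated constructor applies: entries from `tokenizer_config.json` are consumed first, and `special_tokens_map.json` only fills token ids that are still unset; values may be plain strings or AddedToken-style dicts carrying a `"content"` field, and tokens whose id is -1 (absent from the SentencePiece model) are skipped. The token ids and file contents below are hypothetical stand-ins, not taken from convert.py.

```python
# Sketch of the special-token precedence used in SentencePieceVocab.__init__
# (hypothetical token ids and JSON contents; simplified from the diff above).
from typing import Any, Dict

# Stand-ins for sentencepiece's unk_id()/bos_id()/eos_id()/pad_id() results.
TOKEN_NAME_TO_ID: Dict[str, int] = {"unk_token": 0, "bos_token": 1, "eos_token": 2, "pad_token": -1}

tokenizer_config: Dict[str, Any] = {   # tokenizer_config.json (takes priority)
    "bos_token": {"content": "<s>"},   # AddedToken-style dict
    "eos_token": "</s>",               # plain string
}
special_tokens: Dict[str, Any] = {     # special_tokens_map.json (fallback)
    "unk_token": "<unk>",
    "eos_token": "<eos-from-map>",     # ignored: eos already set from tokenizer_config
    "pad_token": "<pad>",              # ignored: id is -1, i.e. not in the model
}

special_tokens_map: Dict[int, str] = {}
for key, value in tokenizer_config.items():
    if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1:
        continue
    special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value
for key, value in special_tokens.items():
    token_id = TOKEN_NAME_TO_ID.get(key, -1)
    if token_id == -1 or token_id in special_tokens_map:
        continue
    special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value

print(special_tokens_map)  # {1: '<s>', 2: '</s>', 0: '<unk>'}
```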