class SentencePieceVocab:
    """Vocabulary built from a SentencePiece model plus an optional added-tokens file.

    The added-tokens file is a JSON object mapping token text to token id.
    Entries whose id is below the base SentencePiece vocabulary size replace
    the text of the corresponding base token; entries at or above that size
    are appended as new tokens and must form a gap-free run starting at the
    base vocabulary size.
    """

    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
        """Load the SentencePiece model and, when given, the added-tokens JSON.

        Args:
            fname_tokenizer: Path to the SentencePiece model file.
            fname_added_tokens: Path to a JSON file mapping token text to id,
                or None when there are no added tokens.

        Raises:
            Exception: If the ids of the new tokens are not sequential
                starting at the base vocabulary size.
        """
        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()

        added_tokens: dict[str, int]
        if fname_added_tokens is not None:
            # Use a context manager so the file handle is closed deterministically.
            with open(fname_added_tokens, encoding="utf-8") as f:
                added_tokens = json.load(f)
        else:
            added_tokens = {}

        tokens_to_replace, new_tokens_list = self._split_added_tokens(added_tokens, vocab_size)

        self.tokens_to_replace = tokens_to_replace
        self.new_tokens_list = new_tokens_list
        self.vocab_size_base: int = vocab_size
        self.vocab_size: int = self.vocab_size_base + len(self.new_tokens_list)
        self.fname_tokenizer = fname_tokenizer
        self.fname_added_tokens = fname_added_tokens

    @staticmethod
    def _split_added_tokens(added_tokens: dict[str, int], vocab_size: int) -> tuple[dict[int, str], list[str]]:
        """Split added tokens into base-token replacements and appended new tokens.

        Args:
            added_tokens: Mapping of token text to token id.
            vocab_size: Size of the base vocabulary.

        Returns:
            A pair ``(tokens_to_replace, new_tokens_list)``: a dict of
            id -> text for ids below ``vocab_size``, and the texts of the
            remaining tokens ordered by ascending id.

        Raises:
            Exception: If the ids at or above ``vocab_size`` are not exactly
                ``vocab_size, vocab_size + 1, ...`` with no gaps.
        """
        tokens_to_replace: dict[int, str] = {}
        new_tokens: dict[int, str] = {}
        for piece, idx in added_tokens.items():
            if idx < vocab_size:
                tokens_to_replace[idx] = piece
            else:
                new_tokens[idx] = piece

        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids = sorted(new_tokens)

        if expected_new_ids != actual_new_ids:
            raise Exception(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        return tokens_to_replace, [new_tokens[token_id] for token_id in actual_new_ids]

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(text, score, token type)`` for every id in the base vocabulary.

        The text comes from ``self.tokens_to_replace`` when the id has an
        entry there, otherwise from the SentencePiece model. Score and token
        type are always taken from the SentencePiece model.
        """
        tokenizer = self.sentencepiece_tokenizer
        for token_id in range(tokenizer.vocab_size()):
            if token_id in self.tokens_to_replace:
                piece = self.tokens_to_replace[token_id]
            else:
                piece = tokenizer.id_to_piece(token_id)
            text: bytes = piece.encode("utf-8")
            score: float = tokenizer.get_score(token_id)

            # Later checks deliberately override earlier ones.
            toktype = gguf.TokenType.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.is_control(token_id):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(token_id): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.is_unused(token_id):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.is_byte(token_id):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(text, score, token type)`` for each appended new token.

        Every new token gets the fixed score -1000.0 and the USER_DEFINED type.
        """
        for text in self.new_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED