diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
index 8dea75e48..b245a328b 100644
--- a/gguf-py/gguf/vocab.py
+++ b/gguf-py/gguf/vocab.py
@@ -11,6 +11,7 @@ from .gguf_writer import GGUFWriter
 
 class SpecialVocab:
     merges: list[str]
+    add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
 
     def __init__(
@@ -19,6 +20,7 @@ class SpecialVocab:
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}
+        self.add_special_token = {}
         self.n_vocab = n_vocab
         self.load_merges = load_merges
         self.merges = []
@@ -29,13 +31,16 @@ class SpecialVocab:
             self._load(Path(path))
 
     def __repr__(self) -> str:
-        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
+        return f'<SpecialVocab with {len(self.merges)} merges, special tokens {self.special_token_ids or "unset"}, add special tokens {self.add_special_token or "unset"}>'
 
     def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
         if self.merges:
             if not quiet:
                 print(f'gguf: Adding {len(self.merges)} merge(s).')
             gw.add_token_merges(self.merges)
+        elif self.load_merges:
+            print('gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
+                file = sys.stderr)
         for typ, tokid in self.special_token_ids.items():
             handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
             if handler is None:
@@ -47,15 +52,49 @@ class SpecialVocab:
             if not quiet:
                 print(f'gguf: Setting special token type {typ} to {tokid}')
             handler(tokid)
+        for typ, add in self.add_special_token.items():
+            if not quiet:
+                print(f'gguf: Setting add special token type {typ} to {add}')
+            gw.add_bool(f'tokenizer.ggml.add_{typ}_token', add)
 
     def _load(self, path: Path) -> None:
-        if not self._try_load_from_tokenizer_json(path):
-            self._try_load_from_config_json(path)
+        self._try_load_from_tokenizer_json(path)
+        self._try_load_from_config_json(path)
+        if self.load_merges and not self.merges:
+            self._try_load_merges_txt(path)
+
+    def _try_load_merges_txt(self, path: Path) -> bool:
+        merges_file = path / 'merges.txt'
+        if not merges_file.is_file():
+            return False
+        with open(merges_file, 'r', encoding = 'utf-8') as fp:
+            first_line = next(fp, '').strip()
+            if not first_line.startswith('#'):
+                fp.seek(0)
+                line_num = 0
+            else:
+                line_num = 1
+            merges = []
+            for line in fp:
+                line_num += 1
+                line = line.strip()
+                if len(line) == 0:
+                    continue
+                parts = line.split(None, 3)
+                if len(parts) != 2:
+                    print(f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
+                        file = sys.stderr)
+                    continue
+                merges.append(f'{parts[0]} {parts[1]}')
+        self.merges = merges
+        return True
 
     def _set_special_token(self, typ: str, tid: Any) -> None:
         if not isinstance(tid, int) or tid < 0:
             return
         if self.n_vocab is None or tid < self.n_vocab:
+            if typ in self.special_token_ids:
+                return
             self.special_token_ids[typ] = tid
             return
         print(
@@ -80,6 +119,9 @@ class SpecialVocab:
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
         for typ in self.special_token_types:
+            add_entry = tokenizer_config.get(f'add_{typ}_token')
+            if isinstance(add_entry, bool):
+                self.add_special_token[typ] = add_entry
             entry = tokenizer_config.get(f'{typ}_token')
             if isinstance(entry, str):
                 tc_content = entry
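
For context, a minimal sketch of how a conversion script might exercise the new `add_{typ}_token` handling; the model directory, output path, architecture, and `n_vocab` below are placeholders, not values from this patch:

```python
# Hypothetical conversion snippet: 'model_dir' and 'out.gguf' are
# placeholder paths, 'llama' is just an example architecture string.
from gguf import GGUFWriter, SpecialVocab

writer = GGUFWriter('out.gguf', 'llama')
special_vocab = SpecialVocab('model_dir', load_merges = True, n_vocab = 32000)

# If model_dir/tokenizer_config.json contains e.g.
#   {"add_bos_token": true, "add_eos_token": false}
# add_to_gguf() now also writes the boolean KV pairs
# tokenizer.ggml.add_bos_token and tokenizer.ggml.add_eos_token.
special_vocab.add_to_gguf(writer)
```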
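The `merges.txt` fallback expects the GPT-2 style format: an optional `#`-prefixed version header followed by one space-separated token pair per line. A self-contained sketch of the resulting behaviour, using a throwaway temp directory in place of a real model directory:

```python
import tempfile
from pathlib import Path

from gguf import SpecialVocab

model_dir = Path(tempfile.mkdtemp())  # stands in for a real model directory
(model_dir / 'merges.txt').write_text(
    '#version: 0.2\n'   # header line, skipped by the parser
    'Ġ t\n'             # well-formed pair, stored as 'Ġ t'
    'Ġ a\n'
    'one two three\n',  # malformed entry, skipped with a stderr warning
    encoding = 'utf-8',
)

# No tokenizer.json or tokenizer_config.json here, so loading falls
# through to the new merges.txt path.
sv = SpecialVocab(model_dir, load_merges = True)
print(sv.merges)  # ['Ġ t', 'Ġ a']
```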
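Because `_load()` now runs both JSON loaders unconditionally, the new guard in `_set_special_token()` keeps the first ID seen for each type, so `tokenizer.json` takes precedence over `config.json` when they disagree. A small illustration, calling the private setter directly purely for demonstration:

```python
import tempfile

from gguf import SpecialVocab

sv = SpecialVocab(tempfile.mkdtemp())  # empty dir: nothing auto-loaded
sv._set_special_token('bos', 1)        # first source wins (e.g. tokenizer.json)
sv._set_special_token('bos', 99)       # later conflicting ID is ignored
print(sv.special_token_ids)            # {'bos': 1}
```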