From 8534197f14df1841df08f3b21cf7386e9c47af58 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Tue, 29 Aug 2023 04:19:18 -0600 Subject: [PATCH] gguf: SpecialVocab: Fix issue with special token content not in a dict gguf: SpecialVocab: Allow skipping handling of merges --- gguf-py/gguf/gguf.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 1f8f7098f..5f377187e 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -754,11 +754,12 @@ class GGUFWriter: class SpecialVocab: + load_merges: bool = False merges: List[str] = [] special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad')) special_token_ids: Dict[str, int] = {} - def __init__(self, path: Path, special_token_types: Optional[Tuple[str, ...]] = None): + def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None): self.special_token_ids = {} if special_token_types is not None: self.special_token_types = special_token_types @@ -774,9 +775,10 @@ class SpecialVocab: return False with open(tokenizer_file, 'r', encoding = 'utf-8') as f: tokenizer = json.load(f) - merges = tokenizer.get('model', {}).get('merges') - if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str): - self.merges = merges + if self.load_merges: + merges = tokenizer.get('model', {}).get('merges') + if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str): + self.merges = merges tokenizer_config_file = path / 'tokenizer_config.json' added_tokens = tokenizer.get('added_tokens') if added_tokens is None or not tokenizer_config_file.is_file(): @@ -784,8 +786,15 @@ class SpecialVocab: with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f: tokenizer_config = json.load(f) for typ in self.special_token_types: - tc_content = (tokenizer_config.get(f'{typ}_token') or {}).get('content') - if not isinstance(tc_content, str): + entry = tokenizer_config.get(f'{typ}_token') + if isinstance(entry, str): + tc_content = entry + elif isinstance(entry, dict): + entry_content = entry.get('content') + if not isinstance(entry_content, str): + continue + tc_content = entry_content + else: continue for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content): if isinstance(maybe_token_id, int): @@ -806,7 +815,6 @@ class SpecialVocab: return True def add_to_gguf(self, gw: GGUFWriter): - # FIXME: Don't always include merges (possibly also don't even load them). if len(self.merges) > 0: print(f'gguf: Adding {len(self.merges)} merge(s).') gw.add_token_merges(self.merges)