gguf: SpecialVocab: Fix issue with special token content not in a dict

gguf: SpecialVocab: Allow skipping handling of merges
KerfuffleV2 2023-08-29 04:19:18 -06:00
parent 4a3d783d3e
commit 8534197f14


@@ -754,11 +754,12 @@ class GGUFWriter:
 class SpecialVocab:
     load_merges: bool = False
     merges: List[str] = []
     special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
     special_token_ids: Dict[str, int] = {}

-    def __init__(self, path: Path, special_token_types: Optional[Tuple[str, ...]] = None):
+    def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
         self.special_token_ids = {}
+        self.load_merges = load_merges
         if special_token_types is not None:
             self.special_token_types = special_token_types
@@ -774,6 +775,7 @@ class SpecialVocab:
             return False
         with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
             tokenizer = json.load(f)
+        if self.load_merges:
             merges = tokenizer.get('model', {}).get('merges')
             if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
                 self.merges = merges
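
This is the "skip handling of merges" half of the commit: the whole merge-extraction block is now gated on the new load_merges flag, so vocabularies that have no merges (for example, typical SentencePiece models) never touch it. A minimal usage sketch of the two modes (the model paths are hypothetical):

    # BPE-style vocab: ask for merges so they end up in the output file.
    sv = SpecialVocab(Path('models/my-bpe-model'), load_merges = True)
    # SentencePiece-style vocab: leave the default and skip merges entirely.
    sv = SpecialVocab(Path('models/my-spm-model'))  # load_merges defaults to False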
@@ -784,8 +786,15 @@ class SpecialVocab:
         with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
         for typ in self.special_token_types:
-            tc_content = (tokenizer_config.get(f'{typ}_token') or {}).get('content')
-            if not isinstance(tc_content, str):
+            entry = tokenizer_config.get(f'{typ}_token')
+            if isinstance(entry, str):
+                tc_content = entry
+            elif isinstance(entry, dict):
+                entry_content = entry.get('content')
+                if not isinstance(entry_content, str):
+                    continue
+                tc_content = entry_content
+            else:
                 continue
             for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
                 if isinstance(maybe_token_id, int):
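
This hunk is the "special token content not in a dict" fix: tokenizer_config.json may give a special token either as a plain string or as an AddedToken-style dict, and the old expression called .get('content') on whatever truthy value was there, raising AttributeError for the string form. Both shapes now resolve to tc_content, and anything else is skipped. The example values below are illustrative, not taken from this commit:

    # "bos_token": "<s>"                                  -> taken as-is by the str branch
    # "bos_token": {"content": "<s>", "lstrip": false}    -> 'content' pulled out by the dict branch
    # "bos_token": null (or missing/malformed)            -> skipped by the else branch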
@@ -806,7 +815,6 @@ class SpecialVocab:
         return True

     def add_to_gguf(self, gw: GGUFWriter):
-        # FIXME: Don't always include merges (possibly also don't even load them).
         if len(self.merges) > 0:
             print(f'gguf: Adding {len(self.merges)} merge(s).')
             gw.add_token_merges(self.merges)
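
With the FIXME resolved by the load_merges flag, add_to_gguf simply writes whatever was loaded. A conversion script would now drive the class roughly like this; a sketch only, with the GGUFWriter arguments and model path being illustrative rather than taken from this commit:

    from pathlib import Path
    import gguf

    gw = gguf.GGUFWriter('output.gguf', 'llama')
    sv = gguf.SpecialVocab(Path('models/my-model'), load_merges = True)
    sv.add_to_gguf(gw)   # adds merges (if loaded) and any bos/eos/unk/sep/pad token ids found
    # ... write out tensors/metadata via gw as usual ...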