gguf: SpecialVocab: Fix issue with special token content not in a dict
gguf: SpecialVocab: Allow skipping handling of merges
This commit is contained in:
parent
4a3d783d3e
commit
8534197f14
1 changed file with 15 additions and 7 deletions
|
@@ -754,11 +754,12 @@ class GGUFWriter:
|
|||
|
||||
|
||||
class SpecialVocab:
|
||||
load_merges: bool = False
|
||||
merges: List[str] = []
|
||||
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
|
||||
special_token_ids: Dict[str, int] = {}
|
||||
|
||||
def __init__(self, path: Path, special_token_types: Optional[Tuple[str, ...]] = None):
|
||||
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
|
||||
self.special_token_ids = {}
|
||||
if special_token_types is not None:
|
||||
self.special_token_types = special_token_types
|
||||
|
@@ -774,9 +775,10 @@ class SpecialVocab:
|
|||
return False
|
||||
with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
|
||||
tokenizer = json.load(f)
|
||||
merges = tokenizer.get('model', {}).get('merges')
|
||||
if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
|
||||
self.merges = merges
|
||||
if self.load_merges:
|
||||
merges = tokenizer.get('model', {}).get('merges')
|
||||
if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
|
||||
self.merges = merges
|
||||
tokenizer_config_file = path / 'tokenizer_config.json'
|
||||
added_tokens = tokenizer.get('added_tokens')
|
||||
if added_tokens is None or not tokenizer_config_file.is_file():
|
||||
|
@@ -784,8 +786,15 @@ class SpecialVocab:
|
|||
with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
for typ in self.special_token_types:
|
||||
tc_content = (tokenizer_config.get(f'{typ}_token') or {}).get('content')
|
||||
if not isinstance(tc_content, str):
|
||||
entry = tokenizer_config.get(f'{typ}_token')
|
||||
if isinstance(entry, str):
|
||||
tc_content = entry
|
||||
elif isinstance(entry, dict):
|
||||
entry_content = entry.get('content')
|
||||
if not isinstance(entry_content, str):
|
||||
continue
|
||||
tc_content = entry_content
|
||||
else:
|
||||
continue
|
||||
for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
|
||||
if isinstance(maybe_token_id, int):
|
||||
|
@@ -806,7 +815,6 @@ class SpecialVocab:
|
|||
return True
|
||||
|
||||
def add_to_gguf(self, gw: GGUFWriter):
|
||||
# FIXME: Don't always include merges (possibly also don't even load them).
|
||||
if len(self.merges) > 0:
|
||||
print(f'gguf: Adding {len(self.merges)} merge(s).')
|
||||
gw.add_token_merges(self.merges)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue