gguf-py: SpecialVocab: Always try available sources for special token ids

gguf-py: SpecialVocab: Try to load merges from merges.txt if not in tokenizer.json

gguf-py: SpecialVocab: Add 'add_bos_token' type bools to GGUF metadata
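
Taken together, these changes mean a conversion script picks up the new behavior without call-site changes. A minimal sketch of the flow (the model directory, vocab size, and output filename below are made up for illustration; SpecialVocab and GGUFWriter are the existing gguf-py classes):

    from pathlib import Path
    from gguf import GGUFWriter, SpecialVocab

    # 'models/example' is a hypothetical HF model directory. With this commit,
    # tokenizer.json, tokenizer_config.json and config.json are all consulted
    # for special token ids (the first source to supply an id wins), and
    # merges.txt serves as a fallback when tokenizer.json provides no merges.
    sv = SpecialVocab(Path('models/example'), load_merges = True, n_vocab = 32000)
    print(sv)  # the repr now also reports any add_*_token flags that were found

    gw = GGUFWriter('example.gguf', 'llama')
    sv.add_to_gguf(gw)  # now also writes tokenizer.ggml.add_{bos,eos}_token bools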
KerfuffleV2 2023-11-10 05:50:45 -07:00
parent 960f912a14
commit 9ce51b69b0

gguf-py/gguf/vocab.py

@@ -11,6 +11,7 @@ from .gguf_writer import GGUFWriter
 
 class SpecialVocab:
     merges: list[str]
+    add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
 
     def __init__(
@@ -19,6 +20,7 @@ class SpecialVocab:
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}
+        self.add_special_token = {}
         self.n_vocab = n_vocab
         self.load_merges = load_merges
         self.merges = []
@@ -29,13 +31,16 @@ class SpecialVocab:
         self._load(Path(path))
 
     def __repr__(self) -> str:
-        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
+        return f'<SpecialVocab with {len(self.merges)} merges, special tokens {self.special_token_ids or "unset"}, add special tokens {self.add_special_token or "unset"}>'
 
     def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
         if self.merges:
             if not quiet:
                 print(f'gguf: Adding {len(self.merges)} merge(s).')
             gw.add_token_merges(self.merges)
+        elif self.load_merges:
+            print('gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
+                file = sys.stderr)
         for typ, tokid in self.special_token_ids.items():
             handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
             if handler is None:
@@ -47,15 +52,49 @@ class SpecialVocab:
             if not quiet:
                 print(f'gguf: Setting special token type {typ} to {tokid}')
             handler(tokid)
+        for typ, add in self.add_special_token.items():
+            if not quiet:
+                print(f'gguf: Setting add special token type {typ} to {add}')
+            gw.add_bool(f'tokenizer.ggml.add_{typ}_token', add)
 
     def _load(self, path: Path) -> None:
-        if not self._try_load_from_tokenizer_json(path):
-            self._try_load_from_config_json(path)
+        self._try_load_from_tokenizer_json(path)
+        self._try_load_from_config_json(path)
+        if self.load_merges and not self.merges:
+            self._try_load_merges_txt(path)
+
+    def _try_load_merges_txt(self, path: Path) -> bool:
+        merges_file = path / 'merges.txt'
+        if not merges_file.is_file():
+            return False
+        with open(merges_file, 'r') as fp:
+            first_line = next(fp, '').strip()
+            if not first_line.startswith('#'):
+                fp.seek(0)
+                line_num = 0
+            else:
+                line_num = 1
+            merges = []
+            for line in fp:
+                line_num += 1
+                line = line.strip()
+                if len(line) == 0:
+                    continue
+                parts = line.split(None, 3)
+                if len(parts) != 2:
+                    print(f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
+                        file = sys.stderr)
+                    continue
+                merges.append(f'{parts[0]} {parts[1]}')
+        self.merges = merges
+        return True
 
     def _set_special_token(self, typ: str, tid: Any) -> None:
         if not isinstance(tid, int) or tid < 0:
             return
+        if self.n_vocab is None or tid < self.n_vocab:
+            if typ in self.special_token_ids:
+                return
+            self.special_token_ids[typ] = tid
+            return
+        print(
@@ -80,6 +119,9 @@ class SpecialVocab:
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
         for typ in self.special_token_types:
+            add_entry = tokenizer_config.get(f'add_{typ}_token')
+            if isinstance(add_entry, bool):
+                self.add_special_token[typ] = add_entry
             entry = tokenizer_config.get(f'{typ}_token')
             if isinstance(entry, str):
                 tc_content = entry
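
For reference, the merges.txt layout the new fallback loader expects is the standard GPT-2-style BPE export: an optional '#version' comment on the first line, then one merge pair per line. An illustrative fragment (not taken from any particular model):

    #version: 0.2
    Ġ t
    Ġ a
    h e

_try_load_merges_txt skips the optional header, ignores blank lines, warns about and skips any line that does not split into exactly two fields, and stores each pair as a single space-joined string such as 'Ġ t', the format add_token_merges expects.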