This commit is contained in:
Jared Van Bortel 2023-11-10 14:46:57 -05:00
parent 9ce51b69b0
commit f22b2f2045

View file

@@ -31,7 +31,9 @@ class SpecialVocab:
self._load(Path(path))
def __repr__(self) -> str:
return f'<SpecialVocab with {len(self.merges)} merges, special tokens {self.special_token_ids or "unset"}, add special tokens {self.add_special_token or "unset"}>'
return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
)
def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
if self.merges:
@@ -39,8 +41,10 @@ class SpecialVocab:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
elif self.load_merges:
print('gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
file = sys.stderr)
print(
'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
file = sys.stderr,
)
for typ, tokid in self.special_token_ids.items():
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
@@ -78,12 +82,14 @@ class SpecialVocab:
for line in fp:
line_num += 1
line = line.strip()
if len(line) == 0:
if not line:
continue
parts = line.split(None, 3)
if len(parts) != 2:
print(f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
file = sys.stderr)
print(
f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
file = sys.stderr,
)
continue
merges.append(f'{parts[0]} {parts[1]}')
self.merges = merges