Get rid of ADDED_TOKENS_FILE, FAST_TOKENIZER_FILE

Galunid 2024-05-26 19:57:02 +02:00
parent a57484ae5b
commit a72b75738b

@@ -15,9 +15,6 @@ from .gguf_writer import GGUFWriter
 logger = logging.getLogger(__name__)
 
-ADDED_TOKENS_FILE = 'added_tokens.json'
-FAST_TOKENIZER_FILE = 'tokenizer.json'
 
 class SpecialVocab:
     merges: list[str]
@@ -212,13 +209,13 @@ class BpeVocab(Vocab):
             try:
                 # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                     added_tokens = json.load(f)
             except FileNotFoundError:
                 pass
         else:
             # "fast" tokenizer
-            fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+            fname_tokenizer = base_path / 'tokenizer.json'
 
             # if this fails, FileNotFoundError propagates to caller
             with open(fname_tokenizer, encoding="utf-8") as f:
@@ -282,7 +279,7 @@ class SentencePieceVocab(Vocab):
         if (fname_tokenizer := base_path / 'tokenizer.model').exists():
             # normal location
             try:
-                with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
+                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                     added_tokens = json.load(f)
             except FileNotFoundError:
                 pass
@@ -350,7 +347,7 @@ class LlamaHfVocab(Vocab):
     name = "hfft"
 
     def __init__(self, base_path: Path):
-        fname_tokenizer = base_path / FAST_TOKENIZER_FILE
+        fname_tokenizer = base_path / 'tokenizer.json'
         # if this fails, FileNotFoundError propagates to caller
         with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)
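
For context, the commit only inlines the two file names where they are used. A minimal sketch of the loading pattern these vocab classes rely on is shown below, assuming a model directory laid out like a Hugging Face checkout; the helper names (load_added_tokens, load_fast_tokenizer) are illustrative and not part of the patch:

import json
from pathlib import Path

def load_added_tokens(model_dir: Path) -> dict[str, int]:
    # 'added_tokens.json' is optional; a missing file just means no added tokens.
    try:
        with open(model_dir / 'added_tokens.json', encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}

def load_fast_tokenizer(model_dir: Path) -> dict:
    # 'tokenizer.json' is required on the "fast" tokenizer path;
    # FileNotFoundError propagates to the caller, as in the classes above.
    with open(model_dir / 'tokenizer.json', encoding="utf-8") as f:
        return json.load(f)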