Add VocabLoader and remove *Vocab class

This commit is contained in:
wonjun Jang 2023-10-30 11:54:59 +00:00 committed by GitHub
parent e71544231c
commit d54764d0b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -297,155 +297,24 @@ class Params:
return params
#
# vocab
#
class BpeVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
fast_tokenizer = fname_tokenizer.name == 'tokenizer.json'
tokenizer_json = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
if fast_tokenizer:
self.bpe_tokenizer = tokenizer_json['model']['vocab']
else:
self.bpe_tokenizer = tokenizer_json
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
if not fast_tokenizer:
tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
if not tokenizer_json_file.is_file():
added_tokens = {}
else:
tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
added_tokens = dict(
(item['content'], item['id'])
for item in tokenizer_json.get('added_tokens', []))
added_tokens = dict(
(token_content, token_id)
for token_content, token_id in added_tokens.items()
# Added tokens here can be duplicates of the main vocabulary.
if token_content not in self.bpe_tokenizer)
vocab_size: int = len(self.bpe_tokenizer)
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
expected_end_id = vocab_size + len(actual_ids) - 1
raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_list = [text for (text, idx) in items]
self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.bpe_tokenizer
from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import]
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
for i, _ in enumerate(tokenizer):
yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.bpe_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
actual_new_ids = sorted(new_tokens.keys())
if expected_new_ids != actual_new_ids:
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
# Token pieces that were added to the base vocabulary.
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
self.vocab_size_base = vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
text: bytes = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
toktype = gguf.TokenType.NORMAL
if tokenizer.is_unknown(i):
toktype = gguf.TokenType.UNKNOWN
if tokenizer.is_control(i):
toktype = gguf.TokenType.CONTROL
# NOTE: I think added_tokens are user defined.
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
if tokenizer.is_unused(i):
toktype = gguf.TokenType.UNUSED
if tokenizer.is_byte(i):
toktype = gguf.TokenType.BYTE
yield text, score, toktype
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.sentencepiece_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class HFVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
class VocabLoader:
def __init__(self, fname_tokenizer: Path) -> None:
try:
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"To use HFVocab, please install the `transformers` package. "
"To use VocabLoader, please install the `transformers` package. "
"You can install it with `pip install transformers`."
) from e
self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer))
vocab_set = {encoded_tok for encoded_tok, id in self.tokenizer.vocab.items()}
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
added_tokens = {
token: tid
for token, tid in self.tokenizer.get_added_vocab().items()
if token not in vocab_set
}
vocab_size: int = self.tokenizer.vocab_size
@ -459,7 +328,6 @@ class HFVocab:
self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.tokenizer
@ -478,10 +346,34 @@ class HFVocab:
yield from self.hf_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<HFVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
def get_vocab_type(self) -> str:
path_candidates = []
vocab_file = "tokenizer.model"
path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
if path_candidate is not None:
return "llama"
path_candidates.append(path_candidate)
vocab_file = "vocab.json"
path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
if path_candidate is not None:
return "gpt2"
path_candidates.append(path_candidate)
vocab_file = "tokenizer.json"
path_candidate = vocab_check_and_append_path(self.fname_tokenizer, vocab_file)
if path_candidate:
return "llama"
path_candidates.append(path_candidate)
raise FileNotFoundError(
f"Could not find {find_candidates} in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir")
Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab | HFVocab'
def __repr__(self) -> str:
return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
Vocab: TypeAlias = 'VocabLoader'
#
# data loading
@ -854,17 +746,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
def check_vocab_size(params: Params, vocab: Vocab) -> None:
if params.n_vocab != vocab.vocab_size:
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
if params.n_vocab == vocab.vocab_size_base:
print("Ignoring added_tokens.json since model matches vocab size without it.")
vocab.added_tokens_list = []
vocab.vocab_size = vocab.vocab_size_base
return
msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
if vocab.fname_added_tokens is not None:
msg += f" combined with {vocab.fname_added_tokens}"
msg += f" has {vocab.vocab_size})."
if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None:
if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
raise Exception(msg)
@ -911,12 +800,9 @@ class OutputFile:
scores.append(score)
toktypes.append(toktype)
if isinstance(vocab, SentencePieceVocab) or isinstance(vocab, HFVocab):
self.gguf.add_tokenizer_model("llama")
elif isinstance(vocab, BpeVocab):
self.gguf.add_tokenizer_model("gpt2")
else:
raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab or HFVocab')
vocab_type = vocab.get_vocab_type()
self.gguf.add_tokenizer_model(vocab_type)
self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores)
self.gguf.add_token_types(toktypes)
@ -1137,50 +1023,16 @@ def vocab_check_and_append_path(path: Path, vocab_file: str) -> bool:
path = None
return path
def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
def load_vocab(path: Path) -> Vocab:
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that.
if path.is_dir():
find_candidates = []
vocab_file = "tokenizer.model"
if vocabtype == "bpe":
vocab_file = "vocab.json"
path_candidate = vocab_check_and_append_path(path, vocab_file)
find_candidates.append(vocab_file)
if path_candidate is None:
vocab_file = "tokenizer.json"
hf_path = vocab_check_and_append_path(path, vocab_file)
find_candidates.append(vocab_file)
if hf_path is not None:
# A case where there is no tokenizer.model but there is a tokenizer.json and it needs to be loaded into HFVocab.
vocabtype = "hf"
else:
raise FileNotFoundError(
f"Could not find {find_candidates} in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir")
else:
path = path_candidate
print(f"Loading vocab file '{path}'")
print(f"Loading vocab file '{path}', type '{vocabtype}'")
added_tokens_path = path.parent / "added_tokens.json"
if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm":
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "hf":
return HFVocab(path, added_tokens_path if added_tokens_path.exists() else None)
else:
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
return VocabLoader(path)
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
@ -1215,7 +1067,6 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe", "hf"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
@ -1261,9 +1112,9 @@ def main(args_in: list[str] | None = None) -> None:
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow?
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
vocab = load_vocab(args.vocab_dir or args.model)
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = isinstance(vocab, BpeVocab) or isinstance(vocab, HFVocab),
load_merges = True,
n_vocab = vocab.vocab_size)
outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
@ -1274,12 +1125,15 @@ def main(args_in: list[str] | None = None) -> None:
vocab = model_plus.vocab
else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype)
vocab = load_vocab(vocab_dir)
# FIXME: Try to respect vocab_dir somehow?
print(f"Vocab info: {vocab}")
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = isinstance(vocab, BpeVocab) or isinstance(vocab, HFVocab),
load_merges = True,
n_vocab = vocab.vocab_size)
print(f"Special vocab info: {special_vocab}")
model = model_plus.model
model = convert_model_names(model, params)
ftype = pick_output_type(model, args.outtype)