feat: Add prototype for identifying the vocab type

This commit is contained in:
teleprint-me 2024-05-19 22:30:37 -04:00
parent dcc5d4241d
commit c6f2a48af7
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -386,13 +386,10 @@ class Model:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
assert tokenizer.vocab_size == vocab_size
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
@ -407,6 +404,7 @@ class Model:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
tokpre = self.get_vocab_base_pre(tokenizer)
return tokens, toktypes, tokpre
# NOTE: this function is generated by convert-hf-to-gguf-update.py
@ -418,7 +416,6 @@ class Model:
# is specific for the BPE pre-tokenizer used by the model
# we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can
# use in llama.cpp to implement the same pre-tokenizer
checksum = sha256(str(tokenizer.vocab).encode()).hexdigest()
logger.debug(f"checksum: {checksum}")
@ -427,12 +424,21 @@ class Model:
# Run the `gguf-py/scripts/gguf-gen-pre.py` script to generate the checksums.
# This script should ideally pull in the latest version of the model from HuggingFace.
# DO NOT MANUALLY EDIT THIS METHOD!
models = json.load("models/checksums.json")
models = json.load(f"{tokenizer.name_or_path}/checksums.json")
for model in models:
if checksum == model["checksum"]:
logger.debug(f"tokenizer.ggml.pre: {repr(model['repo'])}")
pre = None
if model["tokt"] == gguf.TokenizerType.BPE:
pre = "bpe"
elif model["tokt"] == gguf.TokenizerType.SPM:
pre = "spm"
elif model["tokt"] == gguf.TokenizerType.WPM:
pre = "wpm"
else:
raise KeyError()
logger.debug(f"tokenizer checksum: {checksum}")
return model["tokt"] # NOTE: Use the enum to id the vocab
logger.debug(f"tokenizer.ggml.pre: {pre}")
return pre # NOTE: Use the enum to id the vocab
logger.warning("\n")
logger.warning("**************************************************************************************")