Compare vocabs

jaime-m-p 2024-07-09 01:04:22 +02:00
parent a943b42416
commit dec64ef793


@@ -112,9 +112,25 @@ class LibLlamaModel:
         num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
         return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
+        num_tokens = self.lib.llama_n_vocab(self.model)
+        for id in range(num_tokens):
+            if detokenize:
+                text = self.detokenize([id], remove_special=False, unparse_special=True)
+            else:
+                text = self.lib.llama_token_get_text(self.model, id)
+                text = self.ffi.string(text)
+                text = str(text, encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+            vocab.append(text)
+        return vocab
+
 
 class Tokenizer:
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        raise NotImplementedError
+
     def encode(self, text: str) -> list[int]:
         raise NotImplementedError
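The new LibLlamaModel.get_vocab walks every token id and returns either the raw stored token text (via llama_token_get_text) or the string that llama_detokenize produces for that single id. A minimal usage sketch, assuming the LibLlama/LibLlamaModel wrappers defined earlier in this file; the gguf path is a placeholder:

    # sketch: build the vocab two ways and compare their shapes
    model = LibLlamaModel(LibLlama(), "path/to/ggml-vocab.gguf", mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
    raw      = model.get_vocab(detokenize=False)  # token text as stored in the model vocab
    rendered = model.get_vocab(detokenize=True)   # text produced by llama_detokenize for each id
    assert len(raw) == len(rendered)              # both lists are indexed by token id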
@@ -125,7 +141,7 @@ class Tokenizer:
 class TokenizerGroundtruth (Tokenizer):
 
     def __init__(self, dir_tokenizer: str):
-        self.model = AutoTokenizer.from_pretrained(dir_tokenizer)
+        self.model = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
         # guess BOS and EOS
         ids = self.encode("a")
         assert 1 <= len(ids) <= 3
@@ -134,15 +150,24 @@ class TokenizerGroundtruth (Tokenizer):
         self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
         self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
         # build vocab
-        tokens = list(self.model.get_vocab().values())
-        self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
-        self.vocab = list(sorted(self.vocab))
+        self.vocab = self.get_vocab(detokenize=True)
         # tokens and lists
         self.special_tokens = list(self.model.all_special_tokens)
         self.added_tokens = list(self.model.added_tokens_encoder)
         self.bos_token = self.model.bos_token
         self.eos_token = self.model.eos_token
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        max_token_id = max(self.model.get_vocab().values())
+        if detokenize:
+            ids = list(range(max_token_id + 1))
+            vocab = self.model.batch_decode(ids, skip_special_tokens=False)
+        else:
+            vocab = [None] * (max_token_id + 1)
+            for text, id in self.model.get_vocab().items():
+                vocab[id] = text
+        return vocab
+
     def encode(self, text: str) -> list[int]:
         return self.model.encode(text, add_special_tokens=True)
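The ground-truth side mirrors this through Hugging Face transformers: with detokenize=True it decodes every id in 0..max_token_id via batch_decode (keeping special tokens), otherwise it inverts the get_vocab() text-to-id dict into an id-indexed list. Unlike the previous sorted/deduplicated self.vocab, both lists are now indexed by token id, so they can be compared position by position against llama.cpp. A standalone sketch, assuming a generic Hugging Face checkpoint such as "gpt2":

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    max_token_id = max(tok.get_vocab().values())

    # detokenize=True path: one decoded string per id, special tokens kept
    rendered = tok.batch_decode(list(range(max_token_id + 1)), skip_special_tokens=False)

    # detokenize=False path: invert the text->id mapping into an id-indexed list
    raw = [None] * (max_token_id + 1)
    for text, id in tok.get_vocab().items():
        raw[id] = text

    i = tok.get_vocab()["Ġthe"]             # a GPT-2 style BPE piece (assumed present in this vocab)
    print(repr(raw[i]), repr(rendered[i]))  # stored piece 'Ġthe' vs its decoded form ' the'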
@@ -159,6 +184,9 @@ class TokenizerLlamaCpp (Tokenizer):
         self.libllama = LibLlama()
         self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        return self.model.get_vocab(detokenize)
+
     def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)
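On the llama.cpp side get_vocab simply forwards to the ffi wrapper, so both tokenizers now implement the same Tokenizer interface and downstream code can stay generic. For instance, a hypothetical helper (not part of this commit; model paths are placeholders) could query either implementation the same way:

    def vocab_sizes(tok: Tokenizer) -> tuple[int, int]:
        # length of the raw vocab list vs. the detokenized one, for either backend
        return len(tok.get_vocab(detokenize=False)), len(tok.get_vocab(detokenize=True))

    # vocab_sizes(TokenizerGroundtruth("path/to/hf_tokenizer_dir"))
    # vocab_sizes(TokenizerLlamaCpp("path/to/ggml-vocab.gguf"))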
@@ -491,6 +519,34 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp
     logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
 
 
+def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):
+
+    MAX_PRINT_ERRORS = 10
+
+    logger.info("compare_vocabs: ini")
+
+    t_start = time.perf_counter()
+
+    for detokenize in (False, True):
+        vocab1 = tokenizer1.get_vocab(detokenize)
+        vocab2 = tokenizer2.get_vocab(detokenize)
+        if vocab1 != vocab2:
+            num_errors = 0
+            for i in range(max(len(vocab1), len(vocab2))):
+                text1 = vocab1[i] if i < len(vocab1) else ""
+                text2 = vocab2[i] if i < len(vocab2) else ""
+                is_unused = text1.startswith("[UNUSED_TOKEN_")  # AutoTokenizer adds more unused tokens than SentencePiece ?
+                if text1 != text2 and is_unused and text2:
+                    num_errors += 1
+                    if num_errors < MAX_PRINT_ERRORS:
+                        logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
+            if num_errors:
+                logger.error(f" {num_errors=}")
+
+    t_total = time.perf_counter() - t_start
+    logger.info(f"compare_vocabs: end, {t_total=:.3f}")
+
+
 def main(argv: list[str] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
@@ -504,13 +560,16 @@ def main(argv: list[str] = None):
     tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
     tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
 
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
-    compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
-    compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
-    compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
-    compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
-    compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
+    compare_vocabs(tokenizer1, tokenizer2)
+
+    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_representative(tokenizer1))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
     # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
     # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
     # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))