Update bruteforce test:

- Compare tokenizer vocab tokens.
- Bruteforce byte token generator.
- Find minimal mismatched substring.
jaime-m-p 2024-07-20 22:57:59 +02:00
parent 57b1d4f9eb
commit 3d16f647d1

@@ -116,9 +116,25 @@ class LibLlamaModel:
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
def get_vocab(self, detokenize=False) -> list[str]:
vocab: list[str] = []
num_tokens = self.lib.llama_n_vocab(self.model)
for id in range(num_tokens):
if detokenize:
text = self.detokenize([id], remove_special=False, unparse_special=True)
else:
text = self.lib.llama_token_get_text(self.model, id)
text = self.ffi.string(text)
text = str(text, encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
vocab.append(text)
return vocab
class Tokenizer:
def get_vocab(self, detokenize=False) -> list[str]:
raise NotImplementedError
def encode(self, text: str) -> list[int]:
raise NotImplementedError
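
The new LibLlamaModel.get_vocab() returns one string per token id: with detokenize=False it is the raw text stored in the model's vocab, with detokenize=True it is whatever llama_detokenize produces for that single id. A minimal sketch (not part of the commit) of the two per-id views, assuming a LibLlamaModel instance named model:

# Illustrative only: compare the raw vocab text of one token id with the
# text obtained by detokenizing that same id through the C API.
def vocab_views(model: LibLlamaModel, token_id: int) -> tuple[str, str]:
    raw = model.get_vocab(detokenize=False)[token_id]                                  # text stored in the vocab
    detok = model.detokenize([token_id], remove_special=False, unparse_special=True)   # round-trip through llama_detokenize
    return raw, detok
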
@@ -129,7 +145,7 @@ class Tokenizer:
class TokenizerGroundtruth (Tokenizer):
def __init__(self, dir_tokenizer: str):
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
# guess BOS and EOS
ids = self.encode("a")
assert 1 <= len(ids) <= 3
@@ -138,15 +154,24 @@ class TokenizerGroundtruth (Tokenizer):
self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
# build vocab
tokens = list(self.model.get_vocab().values())
self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
self.vocab = list(sorted(self.vocab))
self.vocab = self.get_vocab(detokenize=True)
# tokens and lists
self.special_tokens = list(self.model.all_special_tokens)
self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
self.special_tokens = [self.vocab[i] for i in sorted(self.model.all_special_ids)]
self.added_tokens = [self.vocab[i] for i in sorted(self.model.added_tokens_encoder.values())]
self.bos_token = self.model.bos_token
self.eos_token = self.model.eos_token
def get_vocab(self, detokenize=False) -> list[str]:
max_token_id = max(self.model.get_vocab().values())
if detokenize:
ids = list(range(max_token_id + 1))
vocab = self.model.batch_decode(ids, skip_special_tokens=False)
else:
vocab = [None] * (max_token_id + 1)
for text, id in self.model.get_vocab().items():
vocab[id] = text
return vocab
def encode(self, text: str) -> list[int]:
return self.model.encode(text, add_special_tokens=True)
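
For tokenizers whose raw vocab stores internal markers (for example SentencePiece's "▁" word-boundary marker or byte-level BPE escapes), the raw and detokenized views differ, which is presumably why the ground-truth vocab is now built from the detokenized view. A hedged illustration, where the tokenizer path and the raw piece "▁Hello" are assumptions:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("dir_tokenizer")                  # placeholder path
piece_id = tok.get_vocab()["\u2581Hello"]                             # raw piece with the SP marker (assumed to exist)
decoded = tok.batch_decode([piece_id], skip_special_tokens=False)[0]
# `decoded` no longer contains the raw "\u2581" marker; depending on the
# model's decoder it comes back as " Hello" or "Hello".
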
@@ -163,6 +188,9 @@ class TokenizerLlamaCpp (Tokenizer):
self.libllama = LibLlama()
self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
def get_vocab(self, detokenize=False) -> list[str]:
return self.model.get_vocab(detokenize)
def encode(self, text: str) -> list[int]:
return self.model.tokenize(text, add_special=True, parse_special=True)
@@ -253,6 +281,23 @@ def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
yield from tokenizer.vocab
def generator_byte_tokens() -> Iterator[str]:
"""Brute force check common byte encoding"""
for a, b in ["<>", "[]", "()", ("\\", "")]:
yield from [f"{a}{i}{b}" for i in range(256)]
yield from [f"{a}{i:x}{b}" for i in range(256)]
yield from [f"{a}{i:X}{b}" for i in range(256)]
yield from [f"{a}x{i:x}{b}" for i in range(256)]
yield from [f"{a}x{i:X}{b}" for i in range(256)]
yield from [f"{a}x{i:02x}{b}" for i in range(256)]
yield from [f"{a}x{i:02X}{b}" for i in range(256)]
yield from [f"{a}0x{i:x}{b}" for i in range(256)]
yield from [f"{a}0x{i:X}{b}" for i in range(256)]
yield from [f"{a}0x{i:02x}{b}" for i in range(256)]
yield from [f"{a}0x{i:02X}{b}" for i in range(256)]
yield from [f"{a}{chr(i)}{b}" for i in range(256)]
def generator_ascii_lr_strip() -> Iterator[str]:
WHITESPACES = ["", " ", " "]
CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
@@ -275,10 +320,11 @@ def generator_apostrophe() -> Iterator[str]:
yield char1 + lstrip + "'" + rstrip + char2
yield char1 + char2 + lstrip + "'" + rstrip + "z"
yield "a" + lstrip + "'" + rstrip + char1 + char2
yield "a" + lstrip + "'" + char1 + char2 + rstrip + "z"
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t", " "]
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
for token in all_tokens:
for lstrip in WHITESPACES:
@@ -436,6 +482,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_start = time.perf_counter()
encode_errors = 0
decode_errors = 0
total_tests = 0
MAX_ERRORS = 10
logger.info("%s: %s" % (generator.__qualname__, "ini"))
@@ -455,21 +502,44 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_encode2 += t2 - t1
t_decode1 += t3 - t2
t_decode2 += t4 - t3
if encode_errors < MAX_ERRORS and ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
# compare
encode_ok = ids1 == ids2
decode_ok = check_detokenizer(text, text1, text2)
encode_errors += not encode_ok
decode_errors += not decode_ok
total_tests += 1
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
def _compare(text: str):
ids1 = tokenizer1.encode(text)
ids2 = tokenizer2.encode(text)
text1 = tokenizer1.decode(ids1)
text2 = tokenizer2.decode(ids1)
encode_ok = ids1 == ids2
decode_ok = check_detokenizer(text, text1, text2)
ok = encode_ok and decode_ok
return ok, ids1, ids2, text1, text2
a, b = 0, len(text)
for step in [64, 32, 16, 8, 4, 2, 1]:
while a < b:
t = max(a, b - step)
if _compare(text[a : t])[0]:
break
b = t
for step in [64, 32, 16, 8, 4, 2, 1]:
while a < b:
t = min(a + step, b)
if _compare(text[t : b])[0]:
break
a = t
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
assert a <= b and not ok
logger.error(" Text:" + repr(text[a : b]))
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
logger.error(" Expected: " + str(ids1))
logger.error(" Result: " + str(ids2))
encode_errors += 1
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
logger.error(f" {encode_errors=}")
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
i = find_first_mismatch(text1, text2)
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
decode_errors += 1
logger.error(f" {decode_errors=}")
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
@@ -480,6 +550,43 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):
MAX_PRINT_ERRORS = 10
logger.info("compare_vocabs: ini")
t_start = time.perf_counter()
for detokenize in (False, True):
vocab1 = tokenizer1.get_vocab(detokenize)
vocab2 = tokenizer2.get_vocab(detokenize)
if vocab1 != vocab2:
num_errors = 0
for i in range(max(len(vocab1), len(vocab2))):
text1 = vocab1[i] if i < len(vocab1) else None
text2 = vocab2[i] if i < len(vocab2) else None
if text1 != text2:
# is "[UNUSED_TOKEN_" and "[PAD" valid for all models ? #TODO: use toktypes
if text1 is not None:
text1 = text1.replace("[UNUSED_TOKEN_", "[PAD")
if text2 is not None:
text2 = text2.replace("[UNUSED_TOKEN_", "[PAD")
if text1 is None and (text2 or "").startswith('[PAD'):
text2 = None
if text2 is None and (text1 or "").startswith('[PAD'):
text1 = None
if text1 != text2:
num_errors += 1
if num_errors < MAX_PRINT_ERRORS:
logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
if num_errors:
logger.error(f" {num_errors=}")
t_total = time.perf_counter() - t_start
logger.info(f"compare_vocabs: end, {t_total=:.3f}")
def main(argv: list[str] | None = None):
parser = argparse.ArgumentParser()
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
@@ -493,18 +600,21 @@ def main(argv: list[str] | None = None):
tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
compare_vocabs(tokenizer1, tokenizer2)
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
compare_tokenizers(tokenizer1, tokenizer2, generator_byte_tokens())
compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
tokenizer2.model.free()
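
Since main() accepts an optional argv list, the comparison can also be driven from Python. A sketch, assuming the parser takes the Hugging Face tokenizer directory as a second positional argument (as args.dir_tokenizer above suggests) and with placeholder paths:

# Both paths are placeholders; the GGUF vocab-only file and the HF tokenizer
# directory must describe the same model for the comparison to make sense.
main([
    "models/ggml-vocab-llama-bpe.gguf",
    "models/tokenizers/llama-bpe",
])
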