Update bruteforce test:
- Compare tokenizer vocab tokens. - Bruteforce byte token generator. - Find minimal mismatched substring.
This commit is contained in:
parent
57b1d4f9eb
commit
3d16f647d1
1 changed files with 136 additions and 26 deletions
|
@ -116,9 +116,25 @@ class LibLlamaModel:
|
||||||
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
|
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
|
||||||
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
||||||
|
|
||||||
|
def get_vocab(self, detokenize=False) -> list[str]:
|
||||||
|
vocab: list[str] = []
|
||||||
|
num_tokens = self.lib.llama_n_vocab(self.model)
|
||||||
|
for id in range(num_tokens):
|
||||||
|
if detokenize:
|
||||||
|
text = self.detokenize([id], remove_special=False, unparse_special=True)
|
||||||
|
else:
|
||||||
|
text = self.lib.llama_token_get_text(self.model, id)
|
||||||
|
text = self.ffi.string(text)
|
||||||
|
text = str(text, encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
||||||
|
vocab.append(text)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
|
|
||||||
|
def get_vocab(self, detokenize=False) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def encode(self, text: str) -> list[int]:
|
def encode(self, text: str) -> list[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -129,7 +145,7 @@ class Tokenizer:
|
||||||
class TokenizerGroundtruth (Tokenizer):
|
class TokenizerGroundtruth (Tokenizer):
|
||||||
|
|
||||||
def __init__(self, dir_tokenizer: str):
|
def __init__(self, dir_tokenizer: str):
|
||||||
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
|
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
|
||||||
# guess BOS and EOS
|
# guess BOS and EOS
|
||||||
ids = self.encode("a")
|
ids = self.encode("a")
|
||||||
assert 1 <= len(ids) <= 3
|
assert 1 <= len(ids) <= 3
|
||||||
|
@ -138,15 +154,24 @@ class TokenizerGroundtruth (Tokenizer):
|
||||||
self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
|
self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
|
||||||
self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
|
self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
|
||||||
# build vocab
|
# build vocab
|
||||||
tokens = list(self.model.get_vocab().values())
|
self.vocab = self.get_vocab(detokenize=True)
|
||||||
self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
|
|
||||||
self.vocab = list(sorted(self.vocab))
|
|
||||||
# tokens and lists
|
# tokens and lists
|
||||||
self.special_tokens = list(self.model.all_special_tokens)
|
self.special_tokens = [self.vocab[i] for i in sorted(self.model.all_special_ids)]
|
||||||
self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
|
self.added_tokens = [self.vocab[i] for i in sorted(self.model.added_tokens_encoder.values())]
|
||||||
self.bos_token = self.model.bos_token
|
self.bos_token = self.model.bos_token
|
||||||
self.eos_token = self.model.eos_token
|
self.eos_token = self.model.eos_token
|
||||||
|
|
||||||
|
def get_vocab(self, detokenize=False) -> list[str]:
|
||||||
|
max_token_id = max(self.model.get_vocab().values())
|
||||||
|
if detokenize:
|
||||||
|
ids = list(range(max_token_id + 1))
|
||||||
|
vocab = self.model.batch_decode(ids, skip_special_tokens=False)
|
||||||
|
else:
|
||||||
|
vocab = [None] * (max_token_id + 1)
|
||||||
|
for text, id in self.model.get_vocab().items():
|
||||||
|
vocab[id] = text
|
||||||
|
return vocab
|
||||||
|
|
||||||
def encode(self, text: str) -> list[int]:
|
def encode(self, text: str) -> list[int]:
|
||||||
return self.model.encode(text, add_special_tokens=True)
|
return self.model.encode(text, add_special_tokens=True)
|
||||||
|
|
||||||
|
@ -163,6 +188,9 @@ class TokenizerLlamaCpp (Tokenizer):
|
||||||
self.libllama = LibLlama()
|
self.libllama = LibLlama()
|
||||||
self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
||||||
|
|
||||||
|
def get_vocab(self, detokenize=False) -> list[str]:
|
||||||
|
return self.model.get_vocab(detokenize)
|
||||||
|
|
||||||
def encode(self, text: str) -> list[int]:
|
def encode(self, text: str) -> list[int]:
|
||||||
return self.model.tokenize(text, add_special=True, parse_special=True)
|
return self.model.tokenize(text, add_special=True, parse_special=True)
|
||||||
|
|
||||||
|
@ -253,6 +281,23 @@ def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
|
||||||
yield from tokenizer.vocab
|
yield from tokenizer.vocab
|
||||||
|
|
||||||
|
|
||||||
|
def generator_byte_tokens() -> Iterator[str]:
|
||||||
|
"""Brute force check common byte encoding"""
|
||||||
|
for a, b in ["<>", "[]", "()", ("\\", "")]:
|
||||||
|
yield from [f"{a}{i}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}{i:x}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}{i:X}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}x{i:x}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}x{i:X}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}x{i:02x}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}x{i:02X}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}0x{i:x}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}0x{i:X}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}0x{i:02x}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}0x{i:02X}{b}" for i in range(256)]
|
||||||
|
yield from [f"{a}{chr(i)}{b}" for i in range(256)]
|
||||||
|
|
||||||
|
|
||||||
def generator_ascii_lr_strip() -> Iterator[str]:
|
def generator_ascii_lr_strip() -> Iterator[str]:
|
||||||
WHITESPACES = ["", " ", " "]
|
WHITESPACES = ["", " ", " "]
|
||||||
CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
|
CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
|
||||||
|
@ -275,10 +320,11 @@ def generator_apostrophe() -> Iterator[str]:
|
||||||
yield char1 + lstrip + "'" + rstrip + char2
|
yield char1 + lstrip + "'" + rstrip + char2
|
||||||
yield char1 + char2 + lstrip + "'" + rstrip + "z"
|
yield char1 + char2 + lstrip + "'" + rstrip + "z"
|
||||||
yield "a" + lstrip + "'" + rstrip + char1 + char2
|
yield "a" + lstrip + "'" + rstrip + char1 + char2
|
||||||
|
yield "a" + lstrip + "'" + char1 + char2 + rstrip + "z"
|
||||||
|
|
||||||
|
|
||||||
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
|
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
|
||||||
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
|
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t", " "]
|
||||||
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
|
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
|
||||||
for token in all_tokens:
|
for token in all_tokens:
|
||||||
for lstrip in WHITESPACES:
|
for lstrip in WHITESPACES:
|
||||||
|
@ -436,6 +482,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
encode_errors = 0
|
encode_errors = 0
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
|
total_tests = 0
|
||||||
MAX_ERRORS = 10
|
MAX_ERRORS = 10
|
||||||
|
|
||||||
logger.info("%s: %s" % (generator.__qualname__, "ini"))
|
logger.info("%s: %s" % (generator.__qualname__, "ini"))
|
||||||
|
@ -455,21 +502,44 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_encode2 += t2 - t1
|
t_encode2 += t2 - t1
|
||||||
t_decode1 += t3 - t2
|
t_decode1 += t3 - t2
|
||||||
t_decode2 += t4 - t3
|
t_decode2 += t4 - t3
|
||||||
if encode_errors < MAX_ERRORS and ids1 != ids2:
|
# compare
|
||||||
i = find_first_mismatch(ids1, ids2)
|
encode_ok = ids1 == ids2
|
||||||
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
|
encode_errors += not encode_ok
|
||||||
|
decode_errors += not decode_ok
|
||||||
|
total_tests += 1
|
||||||
|
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
|
||||||
|
def _compare(text: str):
|
||||||
|
ids1 = tokenizer1.encode(text)
|
||||||
|
ids2 = tokenizer2.encode(text)
|
||||||
|
text1 = tokenizer1.decode(ids1)
|
||||||
|
text2 = tokenizer2.decode(ids1)
|
||||||
|
encode_ok = ids1 == ids2
|
||||||
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
|
ok = encode_ok and decode_ok
|
||||||
|
return ok, ids1, ids2, text1, text2
|
||||||
|
a, b = 0, len(text)
|
||||||
|
for step in [64, 32, 16, 8, 4, 2, 1]:
|
||||||
|
while a < b:
|
||||||
|
t = max(a, b - step)
|
||||||
|
if _compare(text[a : t])[0]:
|
||||||
|
break
|
||||||
|
b = t
|
||||||
|
for step in [64, 32, 16, 8, 4, 2, 1]:
|
||||||
|
while a < b:
|
||||||
|
t = min(a + step, b)
|
||||||
|
if _compare(text[t : b])[0]:
|
||||||
|
break
|
||||||
|
a = t
|
||||||
|
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
|
||||||
|
assert a <= b and not ok
|
||||||
|
logger.error(" Text:" + repr(text[a : b]))
|
||||||
|
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
|
||||||
logger.error(" Expected: " + str(ids1))
|
logger.error(" Expected: " + str(ids1))
|
||||||
logger.error(" Result: " + str(ids2))
|
logger.error(" Result: " + str(ids2))
|
||||||
encode_errors += 1
|
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
|
||||||
|
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
|
||||||
logger.error(f" {encode_errors=}")
|
logger.error(f" {encode_errors=}")
|
||||||
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
|
|
||||||
i = find_first_mismatch(text1, text2)
|
|
||||||
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
|
|
||||||
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
|
|
||||||
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
|
|
||||||
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
|
|
||||||
decode_errors += 1
|
|
||||||
logger.error(f" {decode_errors=}")
|
logger.error(f" {decode_errors=}")
|
||||||
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
||||||
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
||||||
|
@ -480,6 +550,43 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):
|
||||||
|
|
||||||
|
MAX_PRINT_ERRORS = 10
|
||||||
|
|
||||||
|
logger.info("compare_vocabs: ini")
|
||||||
|
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
|
||||||
|
for detokenize in (False, True):
|
||||||
|
vocab1 = tokenizer1.get_vocab(detokenize)
|
||||||
|
vocab2 = tokenizer2.get_vocab(detokenize)
|
||||||
|
if vocab1 != vocab2:
|
||||||
|
num_errors = 0
|
||||||
|
for i in range(max(len(vocab1), len(vocab2))):
|
||||||
|
text1 = vocab1[i] if i < len(vocab1) else None
|
||||||
|
text2 = vocab2[i] if i < len(vocab2) else None
|
||||||
|
if text1 != text2:
|
||||||
|
# is "[UNUSED_TOKEN_" and "[PAD" valid for all models ? #TODO: use toktypes
|
||||||
|
if text1 is not None:
|
||||||
|
text1 = text1.replace("[UNUSED_TOKEN_", "[PAD")
|
||||||
|
if text2 is not None:
|
||||||
|
text2 = text2.replace("[UNUSED_TOKEN_", "[PAD")
|
||||||
|
if text1 is None and (text2 or "").startswith('[PAD'):
|
||||||
|
text2 = None
|
||||||
|
if text2 is None and (text1 or "").startswith('[PAD'):
|
||||||
|
text1 = None
|
||||||
|
if text1 != text2:
|
||||||
|
num_errors += 1
|
||||||
|
if num_errors < MAX_PRINT_ERRORS:
|
||||||
|
logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
|
||||||
|
if num_errors:
|
||||||
|
logger.error(f" {num_errors=}")
|
||||||
|
|
||||||
|
t_total = time.perf_counter() - t_start
|
||||||
|
logger.info(f"compare_vocabs: end, {t_total=:.3f}")
|
||||||
|
|
||||||
|
|
||||||
def main(argv: list[str] | None = None):
|
def main(argv: list[str] | None = None):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
|
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
|
||||||
|
@ -493,18 +600,21 @@ def main(argv: list[str] | None = None):
|
||||||
tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
|
tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
|
||||||
tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
|
tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
|
||||||
|
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
|
compare_vocabs(tokenizer1, tokenizer2)
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
|
|
||||||
|
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
|
||||||
|
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
|
||||||
|
compare_tokenizers(tokenizer1, tokenizer2, generator_byte_tokens())
|
||||||
compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
|
compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
|
||||||
compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
|
compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
|
||||||
compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
|
compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
|
||||||
compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
|
||||||
compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
|
||||||
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
|
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
|
||||||
|
|
||||||
tokenizer2.model.free()
|
tokenizer2.model.free()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue