update brute force random test

This commit is contained in:
jaime-m-p 2024-06-13 20:39:55 +02:00
parent 75840fe6a6
commit f58de3174e

View file

@ -161,6 +161,7 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'one <mask>', # jina-v2-es <mask> lstrip=true
'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2
'\xa0aC', # deepseek
]
@ -208,6 +209,7 @@ def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
def generator_random_chars(iterations=100) -> Iterator[str]:
"""Brute force random text with simple characters"""
NUM_WORDS = 400
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
CHARS = list(sorted(set("""
ABCDEFGHIJKLMNOPQRSTUVWXYZ
@ -221,12 +223,38 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
for m in range(iterations):
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
for _ in range(NUM_WORDS):
k = rand.randint(1, 7)
word = rand.choices(CHARS, k=k)
space = rand.choice(WHITESPACES)
text.append("".join(word) + space)
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
def generator_random_unicodes(iterations=100) -> Iterator[str]:
"""Brute force random text with unicode characters"""
NUM_WORDS = 200
MAX_CODEPOINTS = 0x110000
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
def _valid(c):
if 0x0D800 <= c <= 0x0DFFF: return False # surrogates
if 0x0D800 <= c <= 0x0DFFF: return False # unassigned
if 0x60000 <= c <= 0x6FFFF: return False # unassigned
if 0x80000 <= c <= 0x8FFFF: return False # unassigned
return True
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
for _ in range(NUM_WORDS):
k = rand.randint(1, 7)
word = [rand.randrange(1, MAX_CODEPOINTS) for _ in range(k)]
word = [chr(c) for c in word if _valid(c)]
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
@ -264,25 +292,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
yield "".join(text)
def generator_random_bytes(iterations=100) -> Iterator[str]:
"""Brute force random bytes"""
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
k = rand.randint(1, 8)
word = [chr(r) for r in rand.randbytes(k) if r]
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a, b) in enumerate(zip(ids1, ids2)):
@ -292,11 +302,19 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
return -1
return min(len(ids1), len(ids2))
t0 = time.perf_counter()
t_tokenizer1 = 0
t_tokenizer2 = 0
t_start = time.perf_counter()
logger.info("%s: %s" % (generator.__name__, "ini"))
for text in generator:
t0 = time.perf_counter()
ids1 = func_tokenize1(text)
t1 = time.perf_counter()
ids2 = func_tokenize2(text)
t2 = time.perf_counter()
t_tokenizer1 += t1 - t0
t_tokenizer2 += t2 - t1
if ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
@ -304,8 +322,9 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
logger.info(" TokenIDs: " + str(ids1))
logger.info(" Expected: " + str(ids2))
raise Exception()
t1 = time.perf_counter()
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
t_total = time.perf_counter() - t_start
logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
def main(argv: list[str] = None):
@ -334,15 +353,16 @@ def main(argv: list[str] = None):
tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
model.free()
@ -356,21 +376,21 @@ if __name__ == "__main__":
# import os
# tokenizers = os.listdir(path_tokenizers)
tokenizers = [
# "llama-spm", # SPM
# "phi-3", # SPM
# "bert-bge", # WPM
# "jina-v2-en", # WPM
"llama-spm", # SPM
"phi-3", # SPM
"bert-bge", # WPM
"jina-v2-en", # WPM
"gpt-2", # BPE
"llama-bpe", # BPE
"falcon", # BPE
"deepseek-coder", # BPE
"deepseek-llm", # BPE
"starcoder", # BPE
"jina-v2-es", # BPE
"jina-v2-de", # BPE
"jina-v2-code" # BPE
"smaug-bpe" # BPE
"jina-v2-code", # BPE
"smaug-bpe", # BPE
"phi-2", # BPE
"deepseek-coder", # BPE
"deepseek-llm", # BPE
]
for tokenizer in tokenizers: