Update brute force test: testing 'lstrip' and 'rstrip'
parent 8564c1989a
commit 83cac0ec83
1 changed file with 33 additions and 12 deletions
@@ -161,14 +161,34 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
     ]


-def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
-    special_tokens = set(tokenizer.all_special_tokens)
-    special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"])
-    special_tokens = list(sorted(special_tokens))
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
+def generator_added_lr_strip(tokenizer) -> Iterator[str]:
+    WHITESPACES = ["", " ", "  ", "    "]
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens = list(tokenizer.added_tokens_encoder)
+    all_tokens = list(sorted(set(special_tokens + added_tokens)))
+    for token in all_tokens:
+        for lstrip in WHITESPACES:
+            for rstrip in WHITESPACES:
+                yield lstrip + token + rstrip
+                yield "a" + lstrip + token + rstrip
+                yield lstrip + token + rstrip + "z"
+                yield "a" + lstrip + token + rstrip + "z"
+
+
+def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens = list(tokenizer.added_tokens_encoder)
+    separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
+    all_tokens = list(sorted(set(special_tokens + added_tokens + separations)))
     rand = random.Random()
     for m in range(iterations):
         rand.seed(m)
-        words = rand.choices(special_tokens, k=500)
+        words = rand.choices(all_tokens, k=500)
         if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
             while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                 words.pop(0)
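For illustration only, a minimal runnable sketch of what the new `generator_added_lr_strip` produces. The `StubTokenizer` below is a made-up stand-in (not a real Hugging Face tokenizer); it exposes only the two attributes the generator reads, and the generator body simply restates the code added above:

```python
# Sketch, not part of the commit: StubTokenizer is a hypothetical stand-in
# exposing only the attributes generator_added_lr_strip touches.
from typing import Iterator


class StubTokenizer:
    all_special_tokens = ["<s>", "</s>"]        # made-up special tokens
    added_tokens_encoder = {"<|user|>": 32000}  # made-up added token -> id


def generator_added_lr_strip(tokenizer) -> Iterator[str]:
    # same logic as the generator added in the diff above
    WHITESPACES = ["", " ", "  ", "    "]
    special_tokens = list(tokenizer.all_special_tokens)
    added_tokens = list(tokenizer.added_tokens_encoder)
    all_tokens = list(sorted(set(special_tokens + added_tokens)))
    for token in all_tokens:
        for lstrip in WHITESPACES:
            for rstrip in WHITESPACES:
                yield lstrip + token + rstrip
                yield "a" + lstrip + token + rstrip
                yield lstrip + token + rstrip + "z"
                yield "a" + lstrip + token + rstrip + "z"


for text in list(generator_added_lr_strip(StubTokenizer()))[:4]:
    print(repr(text))
# '</s>', 'a</s>', '</s>z', 'a</s>z' -- then the same token wrapped in the
# longer whitespace variants from WHITESPACES.
```

Prefixing "a" and suffixing "z" exercises each token's lstrip/rstrip handling both at text boundaries and when the token sits inside surrounding text.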
@@ -276,8 +296,8 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
         ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
+            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
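The only change in this hunk widens the slice of token ids logged around the first mismatch, from 2 ids after the mismatch position to 5. A small sketch with made-up id lists; `find_first_mismatch` here is a plausible re-implementation for the example, not necessarily the helper the script uses:

```python
# Sketch with made-up token-id lists.
def find_first_mismatch(ids1, ids2):
    for i, (a, b) in enumerate(zip(ids1, ids2)):
        if a != b:
            return i
    return -1 if len(ids1) == len(ids2) else min(len(ids1), len(ids2))


ids1 = [1, 29871, 12, 34, 56, 78, 90, 11, 22]
ids2 = [1, 29871, 13, 34, 56, 78, 90, 11, 22]
i = find_first_mismatch(ids1, ids2)          # i == 2
old_window = ids1[max(0, i - 2) : i + 2 + 1]  # [1, 29871, 12, 34, 56]
new_window = ids1[max(0, i - 2) : i + 5 + 1]  # [1, 29871, 12, 34, 56, 78, 90, 11]
print(old_window, new_window)
```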
@@ -311,8 +331,9 @@ def main(argv: list[str] = None):
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
@@ -330,10 +351,10 @@ if __name__ == "__main__":
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        # "llama-spm",  # SPM
+        "llama-spm",  # SPM
         "phi-3",  # SPM
-        # "jina-v2-en",  # WPM
-        # "bert-bge",  # WPM
+        "jina-v2-en",  # WPM
+        "bert-bge",  # WPM
     ]
 
     for tokenizer in tokenizers: