From 83cac0ec836bcb7066b2468dac0f5ac48943f7a9 Mon Sep 17 00:00:00 2001 From: jaime-m-p <> Date: Mon, 3 Jun 2024 00:51:48 +0200 Subject: [PATCH] Update brute force test: testing 'lstrip' and 'rstrip' --- tests/test-tokenizer-random.py | 45 +++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 9a84d9379..f699af022 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -161,14 +161,34 @@ def generator_custom_text_edge_cases() -> Iterator[str]: ] -def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]: - special_tokens = set(tokenizer.all_special_tokens) - special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "", ""]) - special_tokens = list(sorted(special_tokens)) +def generator_vocab_words(vocab: list[str]) -> Iterator[str]: + """Brute force check all vocab words""" + yield from vocab + + +def generator_added_lr_strip(tokenizer) -> Iterator[str]: + WHITESPACES = ["", " ", " ", " "] + special_tokens = list(tokenizer.all_special_tokens) + added_tokens = list(tokenizer.added_tokens_encoder) + all_tokens = list(sorted(set(special_tokens + added_tokens))) + for token in all_tokens: + for lstrip in WHITESPACES: + for rstrip in WHITESPACES: + yield lstrip + token + rstrip + yield "a" + lstrip + token + rstrip + yield lstrip + token + rstrip + "z" + yield "a" + lstrip + token + rstrip + "z" + + +def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]: + special_tokens = list(tokenizer.all_special_tokens) + added_tokens = list(tokenizer.added_tokens_encoder) + separations = [" ", "\n", "\t", "-", "!", "one", "1", "", ""] + all_tokens = list(sorted(set(special_tokens + added_tokens + separations))) rand = random.Random() for m in range(iterations): rand.seed(m) - words = rand.choices(special_tokens, k=500) + words = rand.choices(all_tokens, k=500) if words[0] == tokenizer.bos_token: # skip spam warning of double BOS while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS words.pop(0) @@ -276,8 +296,8 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g ids2 = func_tokenize2(text) if ids1 != ids2: i = find_first_mismatch(ids1, ids2) - ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1] - ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1] + ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] + ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] logger.info(" TokenIDs: " + str(ids1)) logger.info(" Expected: " + str(ids2)) raise Exception() @@ -311,8 +331,9 @@ def main(argv: list[str] = None): vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True))) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text()) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases()) - test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab)) + test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer)) + test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000)) @@ -324,16 +345,16 @@ def main(argv: list[str] = None): if __name__ == "__main__": # main() - path_tokenizers = "./models/tokenizers/" + path_tokenizers = "./models/tokenizers/" path_vocab_format = "./models/ggml-vocab-%s.gguf" # import os # tokenizers = os.listdir(path_tokenizers) tokenizers = [ - # "llama-spm", # SPM + "llama-spm", # SPM "phi-3", # SPM - # "jina-v2-en", # WPM - # "bert-bge", # WPM + "jina-v2-en", # WPM + "bert-bge", # WPM ] for tokenizer in tokenizers: