From 614d0bb874963832fa2d2b3e3d25d602c3ef59fc Mon Sep 17 00:00:00 2001
From: jaime-m-p <>
Date: Sat, 25 May 2024 04:15:22 +0200
Subject: [PATCH] Update random test: add_eos_token

---
 tests/test-tokenizer-random.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index f69038c87..9e5a352bd 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -168,11 +168,16 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
     for m in range(iterations):
         rand.seed(m)
         words = rand.choices(special_tokens, k=500)
-        if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
+        if words and words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
             while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                 words.pop(0)
             if tokenizer.add_bos_token:  # drop all starting BOS
                 words.pop(0)
+        if words and words[-1] == tokenizer.eos_token:  # skip spam warning of double EOS
+            while len(words) > 1 and words[-2] == tokenizer.eos_token:  # leave one trailing EOS
+                words.pop(-1)
+            if tokenizer.add_eos_token:  # drop all trailing EOS
+                words.pop(-1)
         yield "".join(words)
@@ -305,7 +310,9 @@ def main(argv: list[str] = None):
     ids = func_tokenize2("a")
     assert 1 <= len(ids) <= 3
     add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
+    add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
     tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
+    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
 
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
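Note: in the trailing-EOS branch of the first hunk, the unconditional drop must be gated on `tokenizer.add_eos_token` (mirroring the `add_bos_token` gate in the BOS branch); the patch above reflects that. For readers who want to see the trimming logic in isolation, here is a minimal standalone sketch. The `Tok` stub and the `trim_special_tokens` helper are hypothetical names introduced only for illustration; the stub models just the attributes the patched code reads, everything else about the real Hugging Face tokenizer is assumed.

```python
from dataclasses import dataclass


@dataclass
class Tok:
    """Stand-in for a HF tokenizer; only the attributes the patch uses."""
    bos_token: str = "<s>"
    eos_token: str = "</s>"
    add_bos_token: bool = True   # encoder prepends BOS on its own
    add_eos_token: bool = False  # encoder appends EOS on its own


def trim_special_tokens(words: list[str], tok: Tok) -> list[str]:
    """Collapse duplicated leading BOS / trailing EOS, mirroring the patch."""
    words = list(words)
    if words and words[0] == tok.bos_token:  # avoid double-BOS warning spam
        while len(words) > 1 and words[1] == tok.bos_token:  # keep one leading BOS
            words.pop(0)
        if tok.add_bos_token:  # encoder re-adds BOS, so drop ours entirely
            words.pop(0)
    if words and words[-1] == tok.eos_token:  # avoid double-EOS warning spam
        while len(words) > 1 and words[-2] == tok.eos_token:  # keep one trailing EOS
            words.pop(-1)
        if tok.add_eos_token:  # encoder re-adds EOS, so drop ours entirely
            words.pop(-1)
    return words


tok = Tok(add_eos_token=True)
print(trim_special_tokens(["<s>", "<s>", "x", "</s>", "</s>"], tok))
# -> ['x']  (both BOS and EOS get re-added by the encoder itself)
```

The symmetry matters for the test: if the encoder already adds BOS/EOS, any copy left in the random input would produce a double-token warning that floods the test output without indicating a real mismatch.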
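The second hunk extends the same probe-based detection already used for BOS: encode a one-character string and check whether the EOS id shows up at the end. A sketch of that heuristic as a self-contained function follows; `detect_edge_tokens` and `func_tokenize2` are illustrative names (in the test, `func_tokenize2` wraps the reference tokenizer), not part of any library API.

```python
def detect_edge_tokens(func_tokenize2, tokenizer) -> tuple[bool, bool]:
    """Probe whether the tokenizer adds BOS/EOS around a trivial input."""
    ids = func_tokenize2("a")
    assert 1 <= len(ids) <= 3  # the "a" token, plus at most BOS and EOS
    add_bos = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
    add_eos = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
    # Prefer the tokenizer's own flags when it exposes them; the probe is
    # only a fallback for tokenizers that don't define add_bos/eos_token.
    return (getattr(tokenizer, "add_bos_token", add_bos),
            getattr(tokenizer, "add_eos_token", add_eos))
```

Using `getattr` with the probed value as the default keeps the tokenizer's own metadata authoritative while still giving the generator a usable answer when those attributes are absent.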