diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index f3447d482..4a5773fa5 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -472,10 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl t_decode1 = 0 t_decode2 = 0 t_start = time.perf_counter() + total_tests = 0 + failing_texts = set() encode_errors = 0 decode_errors = 0 - total_tests = 0 - MAX_ERRORS = 10 + MAX_ERRORS = 5 logger.info("%s: %s" % (generator.__qualname__, "ini")) for text in generator: @@ -494,13 +495,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl t_encode2 += t2 - t1 t_decode1 += t3 - t2 t_decode2 += t4 - t3 + total_tests += 1 # compare encode_ok = ids1 == ids2 decode_ok = check_detokenizer(text, text1, text2) - encode_errors += not encode_ok - decode_errors += not decode_ok - total_tests += 1 - if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok): + if not (encode_ok and decode_ok): def _compare(text: str): ids1 = tokenizer1.encode(text) ids2 = tokenizer2.encode(text) @@ -510,33 +509,42 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl decode_ok = check_detokenizer(text, text1, text2) ok = encode_ok and decode_ok return ok, ids1, ids2, text1, text2 + # binary search upper and lower failing range a, b = 0, len(text) - for step in [64, 32, 16, 8, 4, 2, 1]: - while a < b: - t = max(a, b - step) - if _compare(text[a : t])[0]: - break - b = t - for step in [64, 32, 16, 8, 4, 2, 1]: - while a < b: - t = min(a + step, b) - if _compare(text[t : b])[0]: - break - a = t + step = b + while step > 1: + step = step // 2 + if not _compare(text[a : b - step])[0]: + b = b - step + step = b + while step > 1: + step = step // 2 + if not _compare(text[a + step : b])[0]: + a = a + step ok, ids1, ids2, text1, text2 = _compare(text[a : b]) assert a <= b and not ok - logger.error(" Text:" + repr(text[a : b])) - logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b])) - logger.error(" Expected: " + str(ids1)) - logger.error(" Result: " + str(ids2)) - logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1)) - logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2)) - logger.error(f" {encode_errors=}") - logger.error(f" {decode_errors=}") - if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS: - logger.error(f" EXIT: {encode_errors=} {decode_errors=}") - # raise Exception() - break + # show unique failing texts differences + failing_text = text[a : b] + if failing_text not in failing_texts: + failing_texts.add(failing_text) + if encode_errors < MAX_ERRORS and not encode_ok: + encode_errors += 1 + logger.error(f" {encode_errors=}") + logger.error(" Text:" + repr(failing_text)) + logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text)) + logger.error(" Expected: " + str(ids1)) + logger.error(" Result: " + str(ids2)) + if decode_errors < MAX_ERRORS and not decode_ok: + decode_errors += 1 + logger.error(f" {decode_errors=}") + logger.error(" Text:" + repr(failing_text)) + logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text)) + logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1)) + logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2)) + if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS: + logger.error(f" EXIT: {encode_errors=} {decode_errors=}") + # raise Exception() + break t_total = time.perf_counter() - t_start logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") @@ -635,21 +643,19 @@ if __name__ == "__main__": "phi-3", # SPM "gemma", # SPM "gemma-2", # SPM - "baichuan", # SPM + # "baichuan", # SPM "bert-bge", # WPM "jina-v2-en", # WPM + # "t5", # UGM "llama-bpe", # BPE "phi-2", # BPE "deepseek-llm", # BPE "deepseek-coder", # BPE "falcon", # BPE - "mpt", # BPE "starcoder", # BPE "gpt-2", # BPE "stablelm2", # BPE "refact", # BPE - "qwen2", # BPE - "olmo", # BPE "jina-v2-es", # BPE "jina-v2-de", # BPE "smaug-bpe", # BPE @@ -657,6 +663,14 @@ if __name__ == "__main__": "jina-v2-code", # BPE "viking", # BPE "jais", # BPE + "codeshell", # BPE + "tekken", # BPE + "smollm", # BPE + "mpt", # BPE NFC + "command-r", # BPE NFC + "qwen2", # BPE NFC + "olmo", # BPE NFC + "gpt-neox", # BPE NFC ] logger.info("=" * 50)