From a943b424169b051d37f219043c10b17d6efa7a44 Mon Sep 17 00:00:00 2001
From: jaime-m-p <>
Date: Tue, 9 Jul 2024 01:02:44 +0200
Subject: [PATCH] Improve mismatch range localization

---
 tests/test-tokenizer-random.py | 56 ++++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 20 deletions(-)

diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index 5b31cfc9c..9c261fb6a 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -404,14 +404,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int], ids2: list[int]):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
             return True
@@ -431,6 +423,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     t_start = time.perf_counter()
     encode_errors = 0
     decode_errors = 0
+    total_tests = 0
     MAX_ERRORS = 10
 
     logger.info("%s: %s" % (generator.__name__, "ini"))
@@ -450,21 +443,44 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
         t_encode2 += t2 - t1
         t_decode1 += t3 - t2
         t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+        # compare
+        encode_ok = ids1 == ids2
+        decode_ok = check_detokenizer(text, text1, text2)
+        encode_errors += not encode_ok
+        decode_errors += not decode_ok
+        total_tests += 1
+        if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
+            def _compare(text: str):
+                ids1 = tokenizer1.encode(text)
+                ids2 = tokenizer2.encode(text)
+                text1 = tokenizer1.decode(ids1)
+                text2 = tokenizer2.decode(ids1)
+                encode_ok = ids1 == ids2
+                decode_ok = check_detokenizer(text, text1, text2)
+                ok = encode_ok and decode_ok
+                return ok, ids1, ids2, text1, text2
+            a, b = 0, len(text)
+            for step in [64, 32, 16, 8, 4, 2, 1]:
+                while a < b:
+                    t = max(a, b - step)
+                    if _compare(text[a : t])[0]:
+                        break
+                    b = t
+            for step in [64, 32, 16, 8, 4, 2, 1]:
+                while a < b:
+                    t = min(a + step, b)
+                    if _compare(text[t : b])[0]:
+                        break
+                    a = t
+            ok, ids1, ids2, text1, text2 = _compare(text[a : b])
+            assert a <= b and not ok
+            logger.error(" Text:" + repr(text[a : b]))
+            logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
             logger.error(" Expected: " + str(ids1))
             logger.error(" Result: " + str(ids2))
-            encode_errors += 1
+            logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
+            logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
             logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
             logger.error(f" {decode_errors=}")
         if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
             logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
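
Note: the patch replaces the old `find_first_mismatch` point lookup with a two-sided
shrinking search that narrows the failing input down to a small substring before logging
it. The sketch below is a minimal standalone illustration of that strategy, not code from
the patch: `localize_mismatch` and `is_ok` are hypothetical names, with `is_ok` standing
in for the patch's `_compare(text)[0]`; it is assumed to be deterministic and to return
False for the full failing text.

    def localize_mismatch(text: str, is_ok) -> tuple[int, int]:
        a, b = 0, len(text)
        # Shrink from the right: keep cutting the tail while text[a:t] still fails.
        for step in [64, 32, 16, 8, 4, 2, 1]:
            while a < b:
                t = max(a, b - step)
                if is_ok(text[a:t]):  # prefix passes, so the tail matters; refine the step
                    break
                b = t
        # Shrink from the left: keep cutting the head while text[t:b] still fails.
        for step in [64, 32, 16, 8, 4, 2, 1]:
            while a < b:
                t = min(a + step, b)
                if is_ok(text[t:b]):  # suffix passes, so the head matters; refine the step
                    break
                a = t
        return a, b  # text[a:b] still fails and is near-minimal

    # Toy usage: a string "fails" iff it contains '!'.
    text = "aaaaaaaa!aaaaaaaa"
    a, b = localize_mismatch(text, lambda s: "!" not in s)
    print(repr(text[a:b]))  # -> '!'

Each pass removes up to 64 characters per probe before refining the step size, so the
number of re-tokenizations is roughly len(text)/64 plus a small constant per step, far
fewer than probing one character at a time; that matters because every probe re-runs
both tokenizers via `_compare`.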