Improve mismatch range localization

jaime-m-p 2024-07-09 01:02:44 +02:00
parent 9307c3fd46
commit a943b42416


@@ -404,14 +404,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int], ids2: list[int]):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
             return True
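
For reference, the helper removed above only reports the first index at which two token-id sequences diverge: -1 when they are identical, and the shorter length when one is a strict prefix of the other. With illustrative inputs (not taken from the commit):

    find_first_mismatch([1, 2, 3], [1, 9, 3])  # -> 1   (first diverging index)
    find_first_mismatch([1, 2, 3], [1, 2, 3])  # -> -1  (identical)
    find_first_mismatch([1, 2], [1, 2, 3])     # -> 2   (strict prefix: min length)

A single index gives little context for multi-token failures; the last hunk below replaces this with a bisection that narrows down the whole failing substring.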
@@ -431,6 +423,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
     t_start = time.perf_counter()
     encode_errors = 0
     decode_errors = 0
+    total_tests = 0
     MAX_ERRORS = 10
 
     logger.info("%s: %s" % (generator.__name__, "ini"))
@@ -450,21 +443,44 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
         t_encode2 += t2 - t1
         t_decode1 += t3 - t2
         t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+        # compare
+        encode_ok = ids1 == ids2
+        decode_ok = check_detokenizer(text, text1, text2)
+        encode_errors += not encode_ok
+        decode_errors += not decode_ok
+        total_tests += 1
+        if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
+            def _compare(text: str):
+                ids1 = tokenizer1.encode(text)
+                ids2 = tokenizer2.encode(text)
+                text1 = tokenizer1.decode(ids1)
+                text2 = tokenizer2.decode(ids1)
+                encode_ok = ids1 == ids2
+                decode_ok = check_detokenizer(text, text1, text2)
+                ok = encode_ok and decode_ok
+                return ok, ids1, ids2, text1, text2
+            a, b = 0, len(text)
+            for step in [64, 32, 16, 8, 4, 2, 1]:
+                while a < b:
+                    t = max(a, b - step)
+                    if _compare(text[a : t])[0]:
+                        break
+                    b = t
+            for step in [64, 32, 16, 8, 4, 2, 1]:
+                while a < b:
+                    t = min(a + step, b)
+                    if _compare(text[t : b])[0]:
+                        break
+                    a = t
+            ok, ids1, ids2, text1, text2 = _compare(text[a : b])
+            assert a <= b and not ok
+            logger.error(" Text:" + repr(text[a : b]))
+            logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
             logger.error(" Expected: " + str(ids1))
             logger.error(" Result:   " + str(ids2))
-            encode_errors += 1
+            logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
+            logger.error(" Result:   " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
             logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error(" Result:   " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
             logger.error(f" {decode_errors=}")
             if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
                 logger.error(f" EXIT: {encode_errors=} {decode_errors=}")