Improve mismatch range localization
This commit is contained in:
parent
9307c3fd46
commit
a943b42416
1 changed file with 36 additions and 20 deletions
|
@ -404,14 +404,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
|
||||||
|
|
||||||
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
|
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
    """Locate the first index at which the two token-id sequences diverge.

    Returns:
        -1 if the sequences are fully identical;
        the length of the shorter sequence if one is a strict prefix of the
        other; otherwise the index of the first differing element.
    """
    shorter = min(len(ids1), len(ids2))
    for idx in range(shorter):
        if ids1[idx] != ids2[idx]:
            return idx
    # No element-wise difference within the overlap: either truly equal,
    # or one sequence is a prefix of the other.
    return -1 if len(ids1) == len(ids2) else shorter
|
|
||||||
|
|
||||||
def check_detokenizer(text: str, text1: str, text2: str) -> bool:
|
def check_detokenizer(text: str, text1: str, text2: str) -> bool:
|
||||||
if text1 == text2: # equal to TokenizerGroundtruth?
|
if text1 == text2: # equal to TokenizerGroundtruth?
|
||||||
return True
|
return True
|
||||||
|
@ -431,6 +423,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
encode_errors = 0
|
encode_errors = 0
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
|
total_tests = 0
|
||||||
MAX_ERRORS = 10
|
MAX_ERRORS = 10
|
||||||
|
|
||||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||||
|
@ -450,21 +443,44 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_encode2 += t2 - t1
|
t_encode2 += t2 - t1
|
||||||
t_decode1 += t3 - t2
|
t_decode1 += t3 - t2
|
||||||
t_decode2 += t4 - t3
|
t_decode2 += t4 - t3
|
||||||
if encode_errors < MAX_ERRORS and ids1 != ids2:
|
# compare
|
||||||
i = find_first_mismatch(ids1, ids2)
|
encode_ok = ids1 == ids2
|
||||||
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
|
encode_errors += not encode_ok
|
||||||
|
decode_errors += not decode_ok
|
||||||
|
total_tests += 1
|
||||||
|
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
|
||||||
|
def _compare(text: str):
|
||||||
|
ids1 = tokenizer1.encode(text)
|
||||||
|
ids2 = tokenizer2.encode(text)
|
||||||
|
text1 = tokenizer1.decode(ids1)
|
||||||
|
text2 = tokenizer2.decode(ids1)
|
||||||
|
encode_ok = ids1 == ids2
|
||||||
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
|
ok = encode_ok and decode_ok
|
||||||
|
return ok, ids1, ids2, text1, text2
|
||||||
|
a, b = 0, len(text)
|
||||||
|
for step in [64, 32, 16, 8, 4, 2, 1]:
|
||||||
|
while a < b:
|
||||||
|
t = max(a, b - step)
|
||||||
|
if _compare(text[a : t])[0]:
|
||||||
|
break
|
||||||
|
b = t
|
||||||
|
for step in [64, 32, 16, 8, 4, 2, 1]:
|
||||||
|
while a < b:
|
||||||
|
t = min(a + step, b)
|
||||||
|
if _compare(text[t : b])[0]:
|
||||||
|
break
|
||||||
|
a = t
|
||||||
|
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
|
||||||
|
assert a <= b and not ok
|
||||||
|
logger.error(" Text:" + repr(text[a : b]))
|
||||||
|
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
|
||||||
logger.error(" Expected: " + str(ids1))
|
logger.error(" Expected: " + str(ids1))
|
||||||
logger.error(" Result: " + str(ids2))
|
logger.error(" Result: " + str(ids2))
|
||||||
encode_errors += 1
|
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
|
||||||
|
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
|
||||||
logger.error(f" {encode_errors=}")
|
logger.error(f" {encode_errors=}")
|
||||||
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
|
|
||||||
i = find_first_mismatch(text1, text2)
|
|
||||||
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
|
|
||||||
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
|
|
||||||
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
|
|
||||||
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
|
|
||||||
decode_errors += 1
|
|
||||||
logger.error(f" {decode_errors=}")
|
logger.error(f" {decode_errors=}")
|
||||||
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
||||||
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue