Update bruteforce test:

- Faster failing text range selection.
- Show unique failing texts differences.
- Add more recent models.
This commit is contained in:
jaime-m-p 2024-08-05 21:10:45 +02:00
parent fd6d9b9e6a
commit 3b36703c8a

View file

@ -472,10 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_decode1 = 0 t_decode1 = 0
t_decode2 = 0 t_decode2 = 0
t_start = time.perf_counter() t_start = time.perf_counter()
total_tests = 0
failing_texts = set()
encode_errors = 0 encode_errors = 0
decode_errors = 0 decode_errors = 0
total_tests = 0 MAX_ERRORS = 5
MAX_ERRORS = 10
logger.info("%s: %s" % (generator.__qualname__, "ini")) logger.info("%s: %s" % (generator.__qualname__, "ini"))
for text in generator: for text in generator:
@ -494,13 +495,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_encode2 += t2 - t1 t_encode2 += t2 - t1
t_decode1 += t3 - t2 t_decode1 += t3 - t2
t_decode2 += t4 - t3 t_decode2 += t4 - t3
total_tests += 1
# compare # compare
encode_ok = ids1 == ids2 encode_ok = ids1 == ids2
decode_ok = check_detokenizer(text, text1, text2) decode_ok = check_detokenizer(text, text1, text2)
encode_errors += not encode_ok if not (encode_ok and decode_ok):
decode_errors += not decode_ok
total_tests += 1
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
def _compare(text: str): def _compare(text: str):
ids1 = tokenizer1.encode(text) ids1 = tokenizer1.encode(text)
ids2 = tokenizer2.encode(text) ids2 = tokenizer2.encode(text)
@ -510,33 +509,42 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
decode_ok = check_detokenizer(text, text1, text2) decode_ok = check_detokenizer(text, text1, text2)
ok = encode_ok and decode_ok ok = encode_ok and decode_ok
return ok, ids1, ids2, text1, text2 return ok, ids1, ids2, text1, text2
# binary search upper and lower failing range
a, b = 0, len(text) a, b = 0, len(text)
for step in [64, 32, 16, 8, 4, 2, 1]: step = b
while a < b: while step > 1:
t = max(a, b - step) step = step // 2
if _compare(text[a : t])[0]: if not _compare(text[a : b - step])[0]:
break b = b - step
b = t step = b
for step in [64, 32, 16, 8, 4, 2, 1]: while step > 1:
while a < b: step = step // 2
t = min(a + step, b) if not _compare(text[a + step : b])[0]:
if _compare(text[t : b])[0]: a = a + step
break
a = t
ok, ids1, ids2, text1, text2 = _compare(text[a : b]) ok, ids1, ids2, text1, text2 = _compare(text[a : b])
assert a <= b and not ok assert a <= b and not ok
logger.error(" Text:" + repr(text[a : b])) # show unique failing texts differences
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b])) failing_text = text[a : b]
logger.error(" Expected: " + str(ids1)) if failing_text not in failing_texts:
logger.error(" Result: " + str(ids2)) failing_texts.add(failing_text)
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1)) if encode_errors < MAX_ERRORS and not encode_ok:
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2)) encode_errors += 1
logger.error(f" {encode_errors=}") logger.error(f" {encode_errors=}")
logger.error(f" {decode_errors=}") logger.error(" Text:" + repr(failing_text))
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS: logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
logger.error(f" EXIT: {encode_errors=} {decode_errors=}") logger.error(" Expected: " + str(ids1))
# raise Exception() logger.error(" Result: " + str(ids2))
break if decode_errors < MAX_ERRORS and not decode_ok:
decode_errors += 1
logger.error(f" {decode_errors=}")
logger.error(" Text:" + repr(failing_text))
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
# raise Exception()
break
t_total = time.perf_counter() - t_start t_total = time.perf_counter() - t_start
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
@ -635,21 +643,19 @@ if __name__ == "__main__":
"phi-3", # SPM "phi-3", # SPM
"gemma", # SPM "gemma", # SPM
"gemma-2", # SPM "gemma-2", # SPM
"baichuan", # SPM # "baichuan", # SPM
"bert-bge", # WPM "bert-bge", # WPM
"jina-v2-en", # WPM "jina-v2-en", # WPM
# "t5", # UGM
"llama-bpe", # BPE "llama-bpe", # BPE
"phi-2", # BPE "phi-2", # BPE
"deepseek-llm", # BPE "deepseek-llm", # BPE
"deepseek-coder", # BPE "deepseek-coder", # BPE
"falcon", # BPE "falcon", # BPE
"mpt", # BPE
"starcoder", # BPE "starcoder", # BPE
"gpt-2", # BPE "gpt-2", # BPE
"stablelm2", # BPE "stablelm2", # BPE
"refact", # BPE "refact", # BPE
"qwen2", # BPE
"olmo", # BPE
"jina-v2-es", # BPE "jina-v2-es", # BPE
"jina-v2-de", # BPE "jina-v2-de", # BPE
"smaug-bpe", # BPE "smaug-bpe", # BPE
@ -657,6 +663,14 @@ if __name__ == "__main__":
"jina-v2-code", # BPE "jina-v2-code", # BPE
"viking", # BPE "viking", # BPE
"jais", # BPE "jais", # BPE
"codeshell", # BPE
"tekken", # BPE
"smollm", # BPE
"mpt", # BPE NFC
"command-r", # BPE NFC
"qwen2", # BPE NFC
"olmo", # BPE NFC
"gpt-neox", # BPE NFC
] ]
logger.info("=" * 50) logger.info("=" * 50)