Update bruteforce test:
- Faster failing text range selection. - Show unique failing texts differences. - Add more recent models.
This commit is contained in:
parent
fd6d9b9e6a
commit
3b36703c8a
1 changed files with 48 additions and 34 deletions
|
@ -472,10 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_decode1 = 0
|
t_decode1 = 0
|
||||||
t_decode2 = 0
|
t_decode2 = 0
|
||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
|
total_tests = 0
|
||||||
|
failing_texts = set()
|
||||||
encode_errors = 0
|
encode_errors = 0
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
total_tests = 0
|
MAX_ERRORS = 5
|
||||||
MAX_ERRORS = 10
|
|
||||||
|
|
||||||
logger.info("%s: %s" % (generator.__qualname__, "ini"))
|
logger.info("%s: %s" % (generator.__qualname__, "ini"))
|
||||||
for text in generator:
|
for text in generator:
|
||||||
|
@ -494,13 +495,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
t_encode2 += t2 - t1
|
t_encode2 += t2 - t1
|
||||||
t_decode1 += t3 - t2
|
t_decode1 += t3 - t2
|
||||||
t_decode2 += t4 - t3
|
t_decode2 += t4 - t3
|
||||||
|
total_tests += 1
|
||||||
# compare
|
# compare
|
||||||
encode_ok = ids1 == ids2
|
encode_ok = ids1 == ids2
|
||||||
decode_ok = check_detokenizer(text, text1, text2)
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
encode_errors += not encode_ok
|
if not (encode_ok and decode_ok):
|
||||||
decode_errors += not decode_ok
|
|
||||||
total_tests += 1
|
|
||||||
if (encode_errors < MAX_ERRORS and not encode_ok) or (decode_errors < MAX_ERRORS and not decode_ok):
|
|
||||||
def _compare(text: str):
|
def _compare(text: str):
|
||||||
ids1 = tokenizer1.encode(text)
|
ids1 = tokenizer1.encode(text)
|
||||||
ids2 = tokenizer2.encode(text)
|
ids2 = tokenizer2.encode(text)
|
||||||
|
@ -510,33 +509,42 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
||||||
decode_ok = check_detokenizer(text, text1, text2)
|
decode_ok = check_detokenizer(text, text1, text2)
|
||||||
ok = encode_ok and decode_ok
|
ok = encode_ok and decode_ok
|
||||||
return ok, ids1, ids2, text1, text2
|
return ok, ids1, ids2, text1, text2
|
||||||
|
# binary search upper and lower failing range
|
||||||
a, b = 0, len(text)
|
a, b = 0, len(text)
|
||||||
for step in [64, 32, 16, 8, 4, 2, 1]:
|
step = b
|
||||||
while a < b:
|
while step > 1:
|
||||||
t = max(a, b - step)
|
step = step // 2
|
||||||
if _compare(text[a : t])[0]:
|
if not _compare(text[a : b - step])[0]:
|
||||||
break
|
b = b - step
|
||||||
b = t
|
step = b
|
||||||
for step in [64, 32, 16, 8, 4, 2, 1]:
|
while step > 1:
|
||||||
while a < b:
|
step = step // 2
|
||||||
t = min(a + step, b)
|
if not _compare(text[a + step : b])[0]:
|
||||||
if _compare(text[t : b])[0]:
|
a = a + step
|
||||||
break
|
|
||||||
a = t
|
|
||||||
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
|
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
|
||||||
assert a <= b and not ok
|
assert a <= b and not ok
|
||||||
logger.error(" Text:" + repr(text[a : b]))
|
# show unique failing texts differences
|
||||||
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text[a : b]))
|
failing_text = text[a : b]
|
||||||
logger.error(" Expected: " + str(ids1))
|
if failing_text not in failing_texts:
|
||||||
logger.error(" Result: " + str(ids2))
|
failing_texts.add(failing_text)
|
||||||
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
|
if encode_errors < MAX_ERRORS and not encode_ok:
|
||||||
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
|
encode_errors += 1
|
||||||
logger.error(f" {encode_errors=}")
|
logger.error(f" {encode_errors=}")
|
||||||
logger.error(f" {decode_errors=}")
|
logger.error(" Text:" + repr(failing_text))
|
||||||
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
|
||||||
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
logger.error(" Expected: " + str(ids1))
|
||||||
# raise Exception()
|
logger.error(" Result: " + str(ids2))
|
||||||
break
|
if decode_errors < MAX_ERRORS and not decode_ok:
|
||||||
|
decode_errors += 1
|
||||||
|
logger.error(f" {decode_errors=}")
|
||||||
|
logger.error(" Text:" + repr(failing_text))
|
||||||
|
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
|
||||||
|
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
|
||||||
|
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
|
||||||
|
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
|
||||||
|
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
||||||
|
# raise Exception()
|
||||||
|
break
|
||||||
|
|
||||||
t_total = time.perf_counter() - t_start
|
t_total = time.perf_counter() - t_start
|
||||||
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
||||||
|
@ -635,21 +643,19 @@ if __name__ == "__main__":
|
||||||
"phi-3", # SPM
|
"phi-3", # SPM
|
||||||
"gemma", # SPM
|
"gemma", # SPM
|
||||||
"gemma-2", # SPM
|
"gemma-2", # SPM
|
||||||
"baichuan", # SPM
|
# "baichuan", # SPM
|
||||||
"bert-bge", # WPM
|
"bert-bge", # WPM
|
||||||
"jina-v2-en", # WPM
|
"jina-v2-en", # WPM
|
||||||
|
# "t5", # UGM
|
||||||
"llama-bpe", # BPE
|
"llama-bpe", # BPE
|
||||||
"phi-2", # BPE
|
"phi-2", # BPE
|
||||||
"deepseek-llm", # BPE
|
"deepseek-llm", # BPE
|
||||||
"deepseek-coder", # BPE
|
"deepseek-coder", # BPE
|
||||||
"falcon", # BPE
|
"falcon", # BPE
|
||||||
"mpt", # BPE
|
|
||||||
"starcoder", # BPE
|
"starcoder", # BPE
|
||||||
"gpt-2", # BPE
|
"gpt-2", # BPE
|
||||||
"stablelm2", # BPE
|
"stablelm2", # BPE
|
||||||
"refact", # BPE
|
"refact", # BPE
|
||||||
"qwen2", # BPE
|
|
||||||
"olmo", # BPE
|
|
||||||
"jina-v2-es", # BPE
|
"jina-v2-es", # BPE
|
||||||
"jina-v2-de", # BPE
|
"jina-v2-de", # BPE
|
||||||
"smaug-bpe", # BPE
|
"smaug-bpe", # BPE
|
||||||
|
@ -657,6 +663,14 @@ if __name__ == "__main__":
|
||||||
"jina-v2-code", # BPE
|
"jina-v2-code", # BPE
|
||||||
"viking", # BPE
|
"viking", # BPE
|
||||||
"jais", # BPE
|
"jais", # BPE
|
||||||
|
"codeshell", # BPE
|
||||||
|
"tekken", # BPE
|
||||||
|
"smollm", # BPE
|
||||||
|
"mpt", # BPE NFC
|
||||||
|
"command-r", # BPE NFC
|
||||||
|
"qwen2", # BPE NFC
|
||||||
|
"olmo", # BPE NFC
|
||||||
|
"gpt-neox", # BPE NFC
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue