Update bruteforce test
parent 107923cdd2
commit 68220feaf8

1 changed file with 7 additions and 4 deletions
@@ -235,6 +235,8 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'å', # mpt
         '\U000ac517', # utf-8 encode error, falcon
         '\U000522f4', # utf-8 encode error, starcoder
+        "<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>",
+        "<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>",
     ]


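The two added strings stress special-token handling: repeated BOS/UNK markers glued to plain text, with and without separating spaces. As a rough illustration (not part of this test file), such cases can be probed directly with a Hugging Face tokenizer; the model id below is a placeholder:

    # Hypothetical probe: compare special-token splitting with and without spaces.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("some/model")  # placeholder id
    for text in ("<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>",
                 "<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>"):
        # add_special_tokens=False keeps the tokenizer from prepending its own BOS
        print(tokenizer.encode(text, add_special_tokens=False))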
@@ -334,7 +336,7 @@ def generator_unicodes() -> Iterator[str]:
             return False
         # if cpt == 0x2029: # deepseek-llm
         #     return False
-        if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ): # undefined, surrogates, private
+        if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"): # undefined, surrogates, private
             return False
         return True

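For reference, the three filtered general categories are "Cn" (unassigned), "Cs" (surrogate), and "Co" (private use). A quick standalone check using only the standard library:

    import unicodedata

    # Category codes skipped by the generator's validity check above.
    print(unicodedata.category('\ud800'))      # 'Cs' -- lone surrogate half
    print(unicodedata.category('\ue000'))      # 'Co' -- private use area
    print(unicodedata.category('\U000ac517'))  # 'Cn' with current Unicode data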
@@ -426,6 +428,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     t_start = time.perf_counter()
     encode_errors = 0
     decode_errors = 0
+    MAX_ERRORS = 10

     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
@@ -444,7 +447,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
         t_encode2 += t2 - t1
         t_decode1 += t3 - t2
         t_decode2 += t4 - t3
-        if ids1 != ids2:
+        if encode_errors < MAX_ERRORS and ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
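On a mismatch, the slices keep a small window around the first differing position (two items before it, five after) so the error log stays readable. find_first_mismatch itself is not shown in this diff; a plausible sketch consistent with how it is used here:

    # Sketch only: the real helper lives elsewhere in the test file.
    def find_first_mismatch(ids1, ids2):
        for i, (a, b) in enumerate(zip(ids1, ids2)):
            if a != b:
                return i
        # One sequence is a prefix of the other: the first mismatch is at
        # the shorter length, which the caller's window slice handles fine.
        return min(len(ids1), len(ids2))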
@@ -452,7 +455,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
             logger.error(" Result: " + str(ids2))
             encode_errors += 1
             logger.error(f" {encode_errors=}")
-        if not check_detokenizer(text, text1, text2):
+        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
             i = find_first_mismatch(text1, text2)
             text1 = list(text1[max(0, i - 2) : i + 5 + 1])
             text2 = list(text2[max(0, i - 2) : i + 5 + 1])
@@ -460,7 +463,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
             logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
             decode_errors += 1
             logger.error(f" {decode_errors=}")
-        if encode_errors >= 10 or decode_errors >= 10:
+        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
             logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
             break
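The last change flips the early-exit condition from `or` to `and`: previously the loop gave up as soon as one error category reached 10, possibly before any error of the other category had been observed. Now each category is merely capped at MAX_ERRORS for logging (the guards above), and the scan stops early only once both categories are saturated. A self-contained illustration of the new semantics:

    MAX_ERRORS = 10

    def should_stop(encode_errors: int, decode_errors: int) -> bool:
        # New condition: both categories must be saturated before exiting.
        return encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS

    assert not should_stop(10, 3)   # the old `or` logic would have exited here
    assert should_stop(10, 10)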