Update brute force test:
Detokenize special tokens. Replace errors with '\uFFFD' when detokenizing to 'utf-8'. More edge cases. Better detokenization results check.
This commit is contained in:
parent
95a0df5578
commit
4a28063b1f
1 changed files with 32 additions and 15 deletions
|
@ -107,7 +107,7 @@ class LibLlamaModel:
|
|||
while num < 0 and len(self.text_buff) < (16 << 20):
|
||||
self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
|
||||
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
|
||||
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
|
||||
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
|
@ -144,7 +144,7 @@ class TokenizerGroundtruth (Tokenizer):
|
|||
return self.model.encode(text, add_special_tokens=True)
|
||||
|
||||
def decode(self, ids: list[int]) -> str:
|
||||
return self.model.decode(ids, skip_special_tokens=True)
|
||||
return self.model.decode(ids, skip_special_tokens=False)
|
||||
|
||||
|
||||
class TokenizerLlamaCpp (Tokenizer):
|
||||
|
@ -160,7 +160,7 @@ class TokenizerLlamaCpp (Tokenizer):
|
|||
return self.model.tokenize(text, add_special=True, parse_special=True)
|
||||
|
||||
def decode(self, ids: list[int]) -> str:
|
||||
return self.model.detokenize(ids, special=False)
|
||||
return self.model.detokenize(ids, special=True)
|
||||
|
||||
|
||||
def generator_custom_text() -> Iterator[str]:
|
||||
|
@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
|
|||
'\xa0aC', # deepseek
|
||||
'\u2029 \uA3E4', # deepseek-llm
|
||||
"a ?",
|
||||
'å', # mpt
|
||||
'\U000ac517', # utf-8 encode error, falcon
|
||||
'\U000522f4', # utf-8 encode error, starcoder
|
||||
]
|
||||
|
||||
|
||||
|
@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:
|
|||
|
||||
|
||||
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
|
||||
WHITESPACES = ["", " ", " ", " "]
|
||||
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
|
||||
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
|
||||
for token in all_tokens:
|
||||
for lstrip in WHITESPACES:
|
||||
|
@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
|
|||
def _valid(cpt):
|
||||
if cpt >= 0x30000: # unassigned and supplementary
|
||||
return False
|
||||
if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
|
||||
return False
|
||||
# if cpt == 0x2029: # deepseek-llm
|
||||
# return False
|
||||
if unicodedata.category(chr(cpt)) == "Cn":
|
||||
if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ): # undefined, surrogates, private
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
|
|||
yield "".join(text)
|
||||
|
||||
|
||||
def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
|
||||
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
|
||||
|
||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||
|
@ -406,12 +407,25 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
|
|||
return -1
|
||||
return min(len(ids1), len(ids2))
|
||||
|
||||
def check_detokenizer(text: str, text1: str, text2: str) -> bool:
|
||||
if text1 == text2: # equal to TokenizerGroundtruth?
|
||||
return True
|
||||
# equal to source text?
|
||||
if tokenizer1.add_bos_token: # remove BOS
|
||||
if text2.startswith(tokenizer1.bos_token):
|
||||
text2 = text2[len(tokenizer1.bos_token):]
|
||||
if tokenizer1.add_eos_token: # remove EOS
|
||||
if text2.endswith(tokenizer1.eos_token):
|
||||
text2 = text2[:-len(tokenizer1.eos_token)]
|
||||
return text == text2
|
||||
|
||||
t_encode1 = 0
|
||||
t_encode2 = 0
|
||||
t_decode1 = 0
|
||||
t_decode2 = 0
|
||||
t_start = time.perf_counter()
|
||||
num_errors = 0
|
||||
encode_errors = 0
|
||||
decode_errors = 0
|
||||
|
||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||
for text in generator:
|
||||
|
@ -424,7 +438,7 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
|
|||
t2 = time.perf_counter()
|
||||
text1 = tokenizer1.decode(ids1)
|
||||
t3 = time.perf_counter()
|
||||
text2 = tokenizer2.decode(ids2)
|
||||
text2 = tokenizer2.decode(ids1)
|
||||
t4 = time.perf_counter()
|
||||
t_encode1 += t1 - t0
|
||||
t_encode2 += t2 - t1
|
||||
|
@ -436,16 +450,18 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
|
|||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
|
||||
logger.error(" Expected: " + str(ids1))
|
||||
logger.error(" Result: " + str(ids2))
|
||||
num_errors += 1
|
||||
if text1 != text2 and text != text2:
|
||||
encode_errors += 1
|
||||
logger.error(f" {encode_errors=}")
|
||||
if not check_detokenizer(text, text1, text2):
|
||||
i = find_first_mismatch(text1, text2)
|
||||
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
|
||||
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
|
||||
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
|
||||
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
|
||||
num_errors += 1
|
||||
if num_errors >= 10:
|
||||
logger.error(f" EXIT: {num_errors=}")
|
||||
decode_errors += 1
|
||||
logger.error(f" {decode_errors=}")
|
||||
if encode_errors >= 10 or decode_errors >= 10:
|
||||
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
|
||||
# raise Exception()
|
||||
break
|
||||
|
||||
|
@ -504,6 +520,7 @@ if __name__ == "__main__":
|
|||
tokenizers = [
|
||||
"llama-spm", # SPM
|
||||
"phi-3", # SPM
|
||||
"baichuan", # SPM
|
||||
"bert-bge", # WPM
|
||||
"jina-v2-en", # WPM
|
||||
"llama-bpe", # BPE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue