Update brute force test:

Detokenize special tokens.
Replace errors with '\uFFFD' when detokenizing to 'utf-8'.
More edge cases.
Better detokenization results check.
This commit is contained in:
jaime-m-p 2024-06-24 20:56:26 +02:00
parent 95a0df5578
commit 4a28063b1f

@@ -107,7 +107,7 @@ class LibLlamaModel:
while num < 0 and len(self.text_buff) < (16 << 20):
self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
class Tokenizer:
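For reference, the errors="replace" behaviour used in the new detokenize return line can be reproduced in isolation; the byte string below is made up for illustration and is not part of the test:

raw = bytes([0xE6, 0x97, 0xA5, 0x20, 0xFF])            # "日", a space, then an invalid byte
text = str(raw, encoding="utf-8", errors="replace")    # no UnicodeDecodeError is raised
assert text == "日 \ufffd"                              # the invalid byte becomes U+FFFD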
@@ -144,7 +144,7 @@ class TokenizerGroundtruth (Tokenizer):
return self.model.encode(text, add_special_tokens=True)
def decode(self, ids: list[int]) -> str:
return self.model.decode(ids, skip_special_tokens=True)
return self.model.decode(ids, skip_special_tokens=False)
class TokenizerLlamaCpp (Tokenizer):
@@ -160,7 +160,7 @@ class TokenizerLlamaCpp (Tokenizer):
return self.model.tokenize(text, add_special=True, parse_special=True)
def decode(self, ids: list[int]) -> str:
return self.model.detokenize(ids, special=False)
return self.model.detokenize(ids, special=True)
def generator_custom_text() -> Iterator[str]:
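With the two decode changes above, both sides now render special tokens, so the detokenized strings can be compared on the same footing. A minimal sketch of the ground-truth side, assuming the Hugging Face transformers package; "bert-base-uncased" is only an illustrative checkpoint, not one from this test:

from transformers import AutoTokenizer                 # assumption: transformers is installed

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
ids = tok.encode("Hello", add_special_tokens=True)
print(tok.decode(ids, skip_special_tokens=True))        # -> "hello" (markers stripped)
print(tok.decode(ids, skip_special_tokens=False))       # -> "[CLS] hello [SEP]", the analogue of special=True on the llama.cpp side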
@@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'\xa0aC', # deepseek
'\u2029 \uA3E4', # deepseek-llm
"a ?",
'', # mpt
'\U000ac517', # utf-8 encode error, falcon
'\U000522f4', # utf-8 encode error, starcoder
]
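The two new codepoints are unassigned; a quick standalone check (not part of the test file) confirms their Unicode category, and the in-diff comments attribute the resulting UTF-8 encode errors to falcon and starcoder:

import unicodedata

for cpt in (0x000AC517, 0x000522F4):
    print(hex(cpt), unicodedata.category(chr(cpt)))    # both print 'Cn' (unassigned)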
@@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
WHITESPACES = ["", " ", " ", " "]
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
for token in all_tokens:
for lstrip in WHITESPACES:
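With the extended WHITESPACES list, every added/special token gets wrapped in each lstrip/rstrip pair. A standalone sketch of the kind of strings this produces per token; the token string and the double-space entry are assumptions for illustration:

WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]   # as in the diff (exact spacing assumed)
token = "<|endoftext|>"                                             # placeholder added/special token
samples = [lstrip + token + rstrip for lstrip in WHITESPACES for rstrip in WHITESPACES]
print(len(samples))                                                 # 64 combinations per token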
@@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
def _valid(cpt):
if cpt >= 0x30000: # unassigned and supplementary
return False
if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
return False
# if cpt == 0x2029: # deepseek-llm
# return False
if unicodedata.category(chr(cpt)) == "Cn":
if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ): # undefined, surrogates, private
return False
return True
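The removed explicit surrogate range check is now subsumed by the category filter. A standalone restatement of the same logic (outside the test file), with a couple of sanity checks:

import unicodedata

def is_valid_codepoint(cpt: int) -> bool:              # mirrors _valid() above
    if cpt >= 0x30000:                                 # unassigned and supplementary planes
        return False
    # "Cn" = unassigned, "Cs" = surrogate, "Co" = private use
    return unicodedata.category(chr(cpt)) not in ("Cn", "Cs", "Co")

assert is_valid_codepoint(ord("A"))                    # letter: kept
assert not is_valid_codepoint(0xD800)                  # surrogate: dropped
assert not is_valid_codepoint(0xE000)                  # private use: dropped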
@@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
yield "".join(text)
def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -406,12 +407,25 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
return -1
return min(len(ids1), len(ids2))
def check_detokenizer(text: str, text1: str, text2: str) -> bool:
if text1 == text2: # equal to TokenizerGroundtruth?
return True
# equal to source text?
if tokenizer1.add_bos_token: # remove BOS
if text2.startswith(tokenizer1.bos_token):
text2 = text2[len(tokenizer1.bos_token):]
if tokenizer1.add_eos_token: # remove EOS
if text2.endswith(tokenizer1.eos_token):
text2 = text2[:-len(tokenizer1.eos_token)]
return text == text2
t_encode1 = 0
t_encode2 = 0
t_decode1 = 0
t_decode2 = 0
t_start = time.perf_counter()
num_errors = 0
encode_errors = 0
decode_errors = 0
logger.info("%s: %s" % (generator.__name__, "ini"))
for text in generator:
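In words, the new check_detokenizer() accepts the llama.cpp output if it matches the ground-truth text exactly, or if it matches the source text once a leading BOS / trailing EOS marker is removed. A toy illustration of the second acceptance path, with made-up marker and strings:

bos_token = "<s>"                                      # made-up marker for illustration
text, text1, text2 = "Hello", "Hello!", "<s>Hello"     # source, ground truth, llama.cpp result
if text2.startswith(bos_token):                        # strip the leading BOS ...
    text2 = text2[len(bos_token):]
assert text2 == text                                   # ... and compare against the source text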
@@ -424,7 +438,7 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
t2 = time.perf_counter()
text1 = tokenizer1.decode(ids1)
t3 = time.perf_counter()
text2 = tokenizer2.decode(ids2)
text2 = tokenizer2.decode(ids1)
t4 = time.perf_counter()
t_encode1 += t1 - t0
t_encode2 += t2 - t1
@@ -436,16 +450,18 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.error(" Expected: " + str(ids1))
logger.error(" Result: " + str(ids2))
num_errors += 1
if text1 != text2 and text != text2:
encode_errors += 1
logger.error(f" {encode_errors=}")
if not check_detokenizer(text, text1, text2):
i = find_first_mismatch(text1, text2)
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
num_errors += 1
if num_errors >= 10:
logger.error(f" EXIT: {num_errors=}")
decode_errors += 1
logger.error(f" {decode_errors=}")
if encode_errors >= 10 or decode_errors >= 10:
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
# raise Exception()
break
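The decode mismatch report prints a small window of codepoints, as hex, around the first differing character. A standalone sketch of that report with made-up strings:

def find_first_mismatch(a: str, b: str) -> int:        # same idea as the helper above
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return -1 if len(a) == len(b) else min(len(a), len(b))

text1, text2 = "Hello world", "Hello_world"            # made-up expected / result
i = find_first_mismatch(text1, text2)
print(" Expected: " + " ".join(hex(ord(c)) for c in text1[max(0, i - 2): i + 6]))
print(" Result:   " + " ".join(hex(ord(c)) for c in text2[max(0, i - 2): i + 6]))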
@@ -504,6 +520,7 @@ if __name__ == "__main__":
tokenizers = [
"llama-spm", # SPM
"phi-3", # SPM
"baichuan", # SPM
"bert-bge", # WPM
"jina-v2-en", # WPM
"llama-bpe", # BPE