Update brute force test:
Detokenize special tokens. Replace decoding errors with '\uFFFD' when detokenizing to UTF-8. Add more edge cases. Improve the detokenization results check.
parent 95a0df5578
commit 4a28063b1f
1 changed file with 32 additions and 15 deletions
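
For context on the errors="replace" change below: Python's UTF-8 decoder normally raises UnicodeDecodeError on invalid byte sequences, while errors="replace" substitutes each undecodable sequence with the replacement character '\uFFFD'. A minimal standalone sketch (not part of the diff; the byte string is hypothetical) of the behavior the detokenizer wrapper now relies on:

# Standalone sketch: effect of errors="replace" on invalid UTF-8.
# The byte string is hypothetical; only the decoding behavior matters.
raw = b"token text \xe6\x97"  # truncated multi-byte UTF-8 sequence at the end

try:
    raw.decode("utf-8")  # default errors="strict" raises on invalid bytes
except UnicodeDecodeError as e:
    print("strict decode failed:", e.reason)

text = str(raw, encoding="utf-8", errors="replace")  # same call shape as the wrapper below
print(text)              # the invalid bytes become the replacement character
assert "\ufffd" in text  # '\uFFFD' marks the undecodable sequence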
@@ -107,7 +107,7 @@ class LibLlamaModel:
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
+        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
 
 class Tokenizer:
@@ -144,7 +144,7 @@ class TokenizerGroundtruth (Tokenizer):
         return self.model.encode(text, add_special_tokens=True)
 
     def decode(self, ids: list[int]) -> str:
-        return self.model.decode(ids, skip_special_tokens=True)
+        return self.model.decode(ids, skip_special_tokens=False)
 
 
 class TokenizerLlamaCpp (Tokenizer):
@@ -160,7 +160,7 @@ class TokenizerLlamaCpp (Tokenizer):
         return self.model.tokenize(text, add_special=True, parse_special=True)
 
     def decode(self, ids: list[int]) -> str:
-        return self.model.detokenize(ids, special=False)
+        return self.model.detokenize(ids, special=True)
 
 
 def generator_custom_text() -> Iterator[str]:
@@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '\xa0aC',  # deepseek
         '\u2029 \uA3E4',  # deepseek-llm
         "a ?",
+        'å',  # mpt
+        '\U000ac517',  # utf-8 encode error, falcon
+        '\U000522f4',  # utf-8 encode error, starcoder
     ]
 
 
@@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:
 
 
 def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "   "]
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
     all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
@@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
     def _valid(cpt):
         if cpt >= 0x30000:  # unassigned and supplementary
             return False
-        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
-            return False
         # if cpt == 0x2029:  # deepseek-llm
         #     return False
-        if unicodedata.category(chr(cpt)) == "Cn":
+        if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ):  # undefined, surrogates, private
             return False
         return True
 
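The widened category check above can be sanity-checked in isolation; a small sketch (not part of the diff) showing how unicodedata.category labels the codepoints that _valid now rejects:

import unicodedata

# Sketch: categories returned for a letter, a surrogate, a private-use
# codepoint and an unassigned codepoint (the latter three are now skipped).
for cpt in (0x0041, 0xD800, 0xE000, 0x0378):
    print(hex(cpt), unicodedata.category(chr(cpt)))
# 0x41   Lu  letter, still generated
# 0xd800 Cs  surrogate, skipped
# 0xe000 Co  private use, skipped
# 0x378  Cn  unassigned, skipped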
@@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
         yield "".join(text)
 
 
-def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -406,12 +407,25 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
                 return -1
         return min(len(ids1), len(ids2))
 
+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
     t_encode1 = 0
     t_encode2 = 0
     t_decode1 = 0
     t_decode2 = 0
     t_start = time.perf_counter()
-    num_errors = 0
+    encode_errors = 0
+    decode_errors = 0
 
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
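The acceptance rule implemented by check_detokenizer above reads: pass if llama.cpp's detokenization matches the ground truth exactly, or if it matches the original source text once BOS/EOS markers are stripped. A self-contained sketch with hypothetical marker strings and settings (the real code reads them from tokenizer1):

# Hypothetical, self-contained restatement of the check above.
BOS, EOS = "<s>", "</s>"        # assumed marker strings for this sketch
ADD_BOS, ADD_EOS = True, False  # assumed tokenizer configuration

def check_detokenizer(text: str, text1: str, text2: str) -> bool:
    # text: source string, text1: ground-truth decode, text2: llama.cpp decode
    if text1 == text2:                     # equal to the ground truth?
        return True
    if ADD_BOS and text2.startswith(BOS):  # strip a leading BOS marker
        text2 = text2[len(BOS):]
    if ADD_EOS and text2.endswith(EOS):    # strip a trailing EOS marker
        text2 = text2[:-len(EOS)]
    return text == text2                   # equal to the source text?

print(check_detokenizer("hello", "hello", "<s>hello"))  # True: only BOS differs
print(check_detokenizer("hello", "hello", "hello!"))    # False: real mismatch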
@@ -424,7 +438,7 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
         t2 = time.perf_counter()
         text1 = tokenizer1.decode(ids1)
         t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids2)
+        text2 = tokenizer2.decode(ids1)
         t4 = time.perf_counter()
         t_encode1 += t1 - t0
         t_encode2 += t2 - t1
@@ -436,16 +450,18 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.error(" Expected: " + str(ids1))
             logger.error("   Result: " + str(ids2))
-            num_errors += 1
-        if text1 != text2 and text != text2:
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if not check_detokenizer(text, text1, text2):
             i = find_first_mismatch(text1, text2)
             text1 = list(text1[max(0, i - 2) : i + 5 + 1])
             text2 = list(text2[max(0, i - 2) : i + 5 + 1])
             logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
             logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
-            num_errors += 1
-        if num_errors >= 10:
-            logger.error(f" EXIT: {num_errors=}")
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= 10 or decode_errors >= 10:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
             break
 
@@ -504,6 +520,7 @@ if __name__ == "__main__":
     tokenizers = [
         "llama-spm",   # SPM
         "phi-3",       # SPM
+        "baichuan",    # SPM
         "bert-bge",    # WPM
         "jina-v2-en",  # WPM
         "llama-bpe",   # BPE