diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index 04a58a3ec..15ca5bf31 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -107,7 +107,7 @@ class LibLlamaModel:
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8")
+        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
 
 class Tokenizer:
@@ -144,7 +144,7 @@ class TokenizerGroundtruth (Tokenizer):
         return self.model.encode(text, add_special_tokens=True)
 
     def decode(self, ids: list[int]) -> str:
-        return self.model.decode(ids, skip_special_tokens=True)
+        return self.model.decode(ids, skip_special_tokens=False)
 
 
 class TokenizerLlamaCpp (Tokenizer):
@@ -160,7 +160,7 @@ class TokenizerLlamaCpp (Tokenizer):
         return self.model.tokenize(text, add_special=True, parse_special=True)
 
     def decode(self, ids: list[int]) -> str:
-        return self.model.detokenize(ids, special=False)
+        return self.model.detokenize(ids, special=True)
 
 
 def generator_custom_text() -> Iterator[str]:
@@ -232,6 +232,9 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '\xa0aC',  # deepseek
         '\u2029 \uA3E4',  # deepseek-llm
         "a ?",
+        'å',  # mpt
+        '\U000ac517',  # utf-8 encode error, falcon
+        '\U000522f4',  # utf-8 encode error, starcoder
     ]
 
 
@@ -265,7 +268,7 @@ def generator_apostrophe() -> Iterator[str]:
 
 
 def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "   "]
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
     all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
@@ -329,11 +332,9 @@ def generator_unicodes() -> Iterator[str]:
     def _valid(cpt):
         if cpt >= 0x30000:  # unassigned and supplementary
             return False
-        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
-            return False
         # if cpt == 0x2029:  # deepseek-llm
         #     return False
-        if unicodedata.category(chr(cpt)) == "Cn":
+        if unicodedata.category(chr(cpt)) in ( "Cn", "Cs", "Co" ):  # undefined, surrogates, private
             return False
         return True
 
@@ -396,7 +397,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
         yield "".join(text)
 
 
-def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -406,12 +407,25 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
             return -1
         return min(len(ids1), len(ids2))
 
+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
     t_encode1 = 0
     t_encode2 = 0
     t_decode1 = 0
     t_decode2 = 0
     t_start = time.perf_counter()
-    num_errors = 0
+    encode_errors = 0
+    decode_errors = 0
 
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
@@ -424,7 +438,7 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
         t2 = time.perf_counter()
         text1 = tokenizer1.decode(ids1)
         t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids2)
+        text2 = tokenizer2.decode(ids1)
         t4 = time.perf_counter()
         t_encode1 += t1 - t0
         t_encode2 += t2 - t1
@@ -436,16 +450,18 @@ def compare_tokenizers(tokenizer1: Tokenizer, tokenizer2: Tokenizer, generator:
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.error(" Expected: " + str(ids1))
             logger.error(" Result:   " + str(ids2))
-            num_errors += 1
-        if text1 != text2 and text != text2:
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if not check_detokenizer(text, text1, text2):
             i = find_first_mismatch(text1, text2)
             text1 = list(text1[max(0, i - 2) : i + 5 + 1])
             text2 = list(text2[max(0, i - 2) : i + 5 + 1])
             logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
             logger.error(" Result:   " + " ".join(hex(ord(x)) for x in text2))
-            num_errors += 1
-        if num_errors >= 10:
-            logger.error(f" EXIT: {num_errors=}")
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= 10 or decode_errors >= 10:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
             break
 
@@ -504,6 +520,7 @@ if __name__ == "__main__":
     tokenizers = [
         "llama-spm",   # SPM
         "phi-3",       # SPM
+        "baichuan",    # SPM
        "bert-bge",    # WPM
         "jina-v2-en",  # WPM
         "llama-bpe",   # BPE
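Note (not part of the patch): the round-trip rule implemented by the new check_detokenizer() helper, restated as a self-contained sketch. The bos/eos parameters below stand in for the Hugging Face-style tokenizer1.add_bos_token / tokenizer1.bos_token / tokenizer1.add_eos_token / tokenizer1.eos_token attributes the script relies on; the llama.cpp detokenization (text2) passes if it equals the ground-truth decode (text1) or, once any BOS/EOS markers are stripped, the original input text.

# Sketch only -- mirrors the intent of check_detokenizer() in the patch above.
def round_trip_ok(text: str, text1: str, text2: str,
                  add_bos: bool = False, bos_token: str = "",
                  add_eos: bool = False, eos_token: str = "") -> bool:
    if text1 == text2:                                          # equal to the ground-truth decode?
        return True
    if add_bos and bos_token and text2.startswith(bos_token):   # remove BOS
        text2 = text2[len(bos_token):]
    if add_eos and eos_token and text2.endswith(eos_token):     # remove EOS
        text2 = text2[:-len(eos_token)]
    return text == text2                                        # equal to the source text?

# A BOS-prefixed detokenization still counts as a correct round trip:
assert round_trip_ok("hello", "<s> hello", "<s>hello", add_bos=True, bos_token="<s>")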
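Also outside the patch: a short illustration of the two Unicode-handling changes, using only the Python standard library. Decoding a truncated UTF-8 buffer with errors="replace" yields U+FFFD instead of raising (the detokenize buffer can end mid-codepoint), and unicodedata.category() reports the "Cn"/"Cs"/"Co" classes that generator_unicodes() now skips; the code points used here are arbitrary examples.

import unicodedata

# Truncated multi-byte sequence: strict decoding raises UnicodeDecodeError,
# while errors="replace" substitutes U+FFFD.
buf = "å".encode("utf-8")[:1]
print(buf.decode("utf-8", errors="replace"))         # -> '\ufffd'

# Categories skipped by _valid(): Cn (unassigned), Cs (surrogate), Co (private use).
for cpt in (0x0378, 0xD800, 0xE000, 0x0041):
    print(hex(cpt), unicodedata.category(chr(cpt)))  # Cn, Cs, Co, Lu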