Update brute force random test

jaime-m-p 2024-06-14 20:14:29 +02:00
parent 0575023923
commit 8cda5af9fe


@@ -11,13 +11,15 @@ import logging
 import argparse
 import subprocess
 import random
+import unicodedata

 from typing import Callable, Iterator

 import cffi
 from transformers import AutoTokenizer

-logger = logging.getLogger("test-tokenizer-random-bpe")
+logger = logging.getLogger("test-tokenizer-random")

 class LibLlama:
@@ -157,6 +159,7 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '<unk><|endoftext|><s>',  # Phi-3 fail
         'a\na',            # bert fail
         '"`',              # falcon
+        ' \u2e4e',         # falcon
         'a\xa0\xa0\x00b',  # jina-v2-es
         'one <mask>',      # jina-v2-es  <mask> lstrip=true
         'a </s> b',        # rstrip phi-3
@@ -231,19 +234,32 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
         yield "".join(text)


+def generator_unicodes() -> Iterator[str]:
+    """Iterate unicode characters"""
+
+    MAX_CODEPOINTS = 0x30000  # 0x110000
+
+    def _valid(cpt):
+        if cpt >= 0x30000:  # unassigned and supplementary
+            return False
+        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
+            return False
+        if unicodedata.category(chr(cpt)) == "Cn":
+            return False
+        return True
+
+    characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
+
+    yield from characters
+
+
 def generator_random_unicodes(iterations=100) -> Iterator[str]:
     """Brute force random text with unicode characters"""

     NUM_WORDS = 200
-    MAX_CODEPOINTS = 0x110000
     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)

-    def _valid(c):
-        if 0x0D800 <= c <= 0x0DFFF: return False  # surrogates
-        if 0x0D800 <= c <= 0x0DFFF: return False  # unassigned
-        if 0x60000 <= c <= 0x6FFFF: return False  # unassigned
-        if 0x80000 <= c <= 0x8FFFF: return False  # unassigned
-        return True
+    characters = list(generator_unicodes())

     rand = random.Random()
     for m in range(iterations):
@@ -251,8 +267,7 @@ def generator_random_unicodes(iterations=100) -> Iterator[str]:
         text = []
         for _ in range(NUM_WORDS):
             k = rand.randint(1, 7)
-            word = [rand.randrange(1, MAX_CODEPOINTS) for _ in range(k)]
-            word = [chr(c) for c in word if _valid(c)]
+            word = rand.choices(characters, k=k)
             word.append(rand.choice(WHITESPACES))
             text.append("".join(word))
         yield "".join(text)
@@ -305,9 +320,11 @@ def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
     t_tokenizer1 = 0
     t_tokenizer2 = 0
     t_start = time.perf_counter()
+    num_errors = 10

     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
+        # print(repr(text), hex(ord(text[0])), text.encode())
         t0 = time.perf_counter()
         ids1 = func_tokenize1(text)
         t1 = time.perf_counter()
@@ -319,9 +336,12 @@ def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.info(" TokenIDs: " + str(ids1))
-            logger.info(" Expected: " + str(ids2))
-            raise Exception()
+            logger.error(" TokenIDs: " + str(ids1))
+            logger.error(" Expected: " + str(ids2))
+            # raise Exception()
+            num_errors += 1
+            if num_errors > 10:
+                break

     t_total = time.perf_counter() - t_start
     logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
@@ -334,7 +354,8 @@ def main(argv: list[str] = None):
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     args = parser.parse_args(argv)

-    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+    logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
+    logger.info(f"VOCABFILE: '{args.vocab_file}'")

     model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
@@ -356,6 +377,7 @@ def main(argv: list[str] = None):
     compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
     compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
     compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
     compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
     compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
@ -370,16 +392,24 @@ def main(argv: list[str] = None):
if __name__ == "__main__": if __name__ == "__main__":
# main() # main()
logging.basicConfig(
level = logging.DEBUG,
format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
datefmt = "%Y-%m-%d %H:%M:%S",
filename = logger.name + ".log",
filemode = "a"
)
path_tokenizers = "./models/tokenizers/" path_tokenizers = "./models/tokenizers/"
path_vocab_format = "./models/ggml-vocab-%s.gguf" path_vocab_format = "./models/ggml-vocab-%s.gguf"
# import os # import os
# tokenizers = os.listdir(path_tokenizers) # tokenizers = os.listdir(path_tokenizers)
tokenizers = [ tokenizers = [
"llama-spm", # SPM # "llama-spm", # SPM
"phi-3", # SPM # "phi-3", # SPM
"bert-bge", # WPM # "bert-bge", # WPM
"jina-v2-en", # WPM # "jina-v2-en", # WPM
"gpt-2", # BPE "gpt-2", # BPE
"llama-bpe", # BPE "llama-bpe", # BPE
"falcon", # BPE "falcon", # BPE
@@ -394,7 +424,8 @@ if __name__ == "__main__":
     ]

     for tokenizer in tokenizers:
-        print("\n" + "=" * 50 + "\n" + tokenizer + "\n")  # noqa
+        logger.info("=" * 50)
+        logger.info(f"TOKENIZER: '{tokenizer}'")
         vocab_file = path_vocab_format % tokenizer
         dir_tokenizer = path_tokenizers + "/" + tokenizer
         main([vocab_file, dir_tokenizer, "--verbose"])
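With the __main__ block now configuring file logging, each run appends to test-tokenizer-random.log (logger.name + ".log") instead of printing banner lines to stdout. A sketch of testing a single model through the same entry point, assuming the script's placeholder directory layout (the paths below are illustrative, not guaranteed to exist):

# Hypothetical single-model run, mirroring one iteration of the loop above.
vocab_file = "./models/ggml-vocab-llama-bpe.gguf"  # placeholder path
dir_tokenizer = "./models/tokenizers/llama-bpe"    # placeholder path
main([vocab_file, dir_tokenizer, "--verbose"])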