Update and bugfix brute force random test
parent e44e608239 · commit bb205eeff8
1 changed file with 39 additions and 34 deletions
@@ -1,5 +1,5 @@
 # Test libllama tokenizer == AutoTokenizer.
-# Brute force random tokens/text generation.
+# Brute force random words/text generation.
 #
 # Sample usage:
 #
@@ -12,7 +12,7 @@ import argparse
 import subprocess
 import random
 
-from typing import Iterator
+from typing import Callable, Iterator
 
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -152,10 +152,17 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
     'a 〇b', # unicode_ranges_digit, 0x3007
     'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
     '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
-    '<s>a' # TODO: Phi-3 fail
+    'Cửa Việt', # llama-3, ignore_merges = true
+    '<s>a', # TODO: Phi-3 fail
+    'a\na', # TODO: Bert fail
 ]
 
 
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
 def generator_random_chars(iterations = 100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
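Note: generator_vocab_words is the simplest of the new generators; it just replays every decoded vocab word once, so both tokenizers are compared on each word in isolation. A minimal usage sketch (the toy vocab here is hypothetical, for illustration only):

    # Hypothetical three-word vocab; the real list is built in __main__
    # below via tokenizer.batch_decode(...).
    toy_vocab = ["Hello", " world", "Cửa"]
    assert list(generator_vocab_words(toy_vocab)) == ["Hello", " world", "Cửa"]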
@@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
         yield "".join(text)
 
 
-def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator[str]:
     """Brute force random text with vocab characters"""
 
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_chars = list(set(vocab_text))
-    del vocab_ids, vocab_text
+    vocab_chars = set()
+    for word in vocab:
+        vocab_chars.update(word)
+    vocab_chars = list(vocab_chars)
 
     rand = random.Random()
     for m in range(iterations):
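Note: generator_random_vocab_chars no longer needs the tokenizer object; the character pool now comes from the already-decoded vocab word list, since set.update(word) adds each character of the string. A minimal sketch of the new pool construction, assuming a hypothetical decoded vocab:

    vocab = ["Hello", " world"]      # hypothetical decoded vocab words
    vocab_chars = set()
    for word in vocab:
        vocab_chars.update(word)     # adds the individual characters of word
    vocab_chars = list(vocab_chars)  # sampling pool, order arbitrary
    assert set(vocab_chars) == set("Helo wrd")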
@@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations
         yield "".join(text)
 
 
-def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
-    """Brute force random text from vocab tokens"""
+def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
+    """Brute force random text from vocab words"""
 
-    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_ids = list(sorted(vocab_ids + vocab_ids))
-    for i in range(1, len(vocab_ids), 2):
-        vocab_ids[i] = space_id
-    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_tokens = vocab_tokens.split(" ")
-    del vocab_ids
-
-    yield from vocab_tokens
-
+    vocab = [w.strip() for w in vocab]
+    yield from vocab
 
     rand = random.Random()
     for m in range(iterations):
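Note: as far as the diff shows, this is the bugfix in the commit title. The old code rebuilt words by decoding an id sequence interleaved with space ids and splitting on " ", which mangles vocab entries that themselves contain spaces; the new code samples the decoded word list directly. A self-contained sketch of the rewritten generator, where the loop body elided by the diff (text = [] and the per-iteration rand.seed(m)) is an assumption modeled on the surrounding generators:

    import random
    from typing import Iterator

    def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
        """Brute force random text from vocab words"""
        vocab = [w.strip() for w in vocab]
        yield from vocab
        rand = random.Random()
        for m in range(iterations):
            rand.seed(m)  # assumed: deterministic per-iteration seed
            text = []     # assumed: fresh buffer per iteration
            num_words = rand.randint(300, 400)
            for i in range(num_words):
                k = rand.randint(1, 3)
                words = rand.choices(vocab, k=k)
                sep = rand.choice(" \n\r\t")
                text.append("".join(words) + sep)
            yield "".join(text)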
@@ -217,10 +216,9 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
-            tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [t.strip(" \n\r\t") for t in tokens]
+            words = rand.choices(vocab, k=k)
             sep = rand.choice(" \n\r\t")
-            text.append("".join(tokens) + sep)
+            text.append("".join(words) + sep)
         yield "".join(text)
 
 
|
@ -242,7 +240,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
|
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||||
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
||||||
|
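Note: test_compare_tokenizer now receives two plain Callables instead of a LibLlamaModel plus a PreTrainedTokenizerBase, so the comparison loop is decoupled from any concrete tokenizer API. A sketch with stand-in tokenize functions (the stand-ins and the tail of find_first_mismatch beyond the two lines shown above are illustrative assumptions):

    from typing import Callable

    def find_first_mismatch(ids1: list[int], ids2: list[int]):
        for i, (a, b) in enumerate(zip(ids1, ids2)):
            if a != b:
                return i  # index of the first differing id
        # assumed tail: -1 means the sequences match
        return -1 if len(ids1) == len(ids2) else min(len(ids1), len(ids2))

    func_tokenize1: Callable = lambda text: [ord(c) for c in text]          # stand-in
    func_tokenize2: Callable = lambda text: [ord(c) for c in text.upper()]  # stand-in
    assert find_first_mismatch(func_tokenize1("aB"), func_tokenize2("aB")) == 0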
@@ -255,8 +253,8 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
     t0 = time.perf_counter()
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
+        ids1 = func_tokenize1(text)
+        ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
@@ -281,15 +279,22 @@ if __name__ == "__main__":
 
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
 
-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
-
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
+    def func_tokenize2(text:str):
+        return tokenizer.encode(text, add_special_tokens=False)
 
-    test_compare_tokenizer(model, tokenizer, generator_custom_text())
-    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
-    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
-    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
+    def func_tokenize1(text:str):
+        return model.tokenize(text, add_special=False, parse_special=parse_special)
+
+    vocab = tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
+    # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
 
     model.free()
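Note: besides the plumbing, the driver changes behavior in two ways: n_ctx grows from 2048 to 4096 (the random generators emit 300-400 words per sample, so a text can exceed the old context), and parse_special is now detected instead of hard-coded to False. The detection logic: if AutoTokenizer encodes every special token to exactly one id, it treats specials as atomic, so libllama should parse them as well. A sketch against a hypothetical minimal tokenizer:

    class FakeTokenizer:  # hypothetical, for illustration only
        all_special_tokens = ["<s>", "</s>"]
        def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
            # specials map to one id; other text to one id per character
            return [0] if text in self.all_special_tokens else list(range(len(text)))

    tokenizer = FakeTokenizer()

    def func_tokenize2(text: str):
        return tokenizer.encode(text, add_special_tokens=False)

    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
    assert parse_special is True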