Refactor random tokenizer test

This commit is contained in:
jaime-m-p 2024-05-08 23:30:52 +02:00
parent 70ca1fe204
commit 77cbb79532

View file

@ -1,15 +1,19 @@
# tests with BPE tokenizer
# Test libllama tokenizer == AutoTokenizer.
# Brute force random tokens/text generation.
#
# sample usage:
# Sample usage:
#
# python3 tests/test-tokenizer-0-bpe.py ./models/ggml-vocab-llama-bpe.gguf ~/Data/huggingface/Meta-Llama-3-8B-Instruct/
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
#
import time
import logging
import argparse
import subprocess
import random
from typing import Iterator
import cffi
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@ -61,12 +65,12 @@ class LibLlamaModel:
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
self.lib = libllama.lib
self.ffi = libllama.ffi
if type(mparams) == dict:
if isinstance(mparams, dict):
mparams = libllama.model_default_params(**mparams)
self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
if not self.model:
raise RuntimeError("error: failed to load model '%s'" % path_model)
if type(cparams) == dict:
if isinstance(cparams, dict):
cparams = libllama.context_default_params(**cparams)
self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
if not self.ctx:
@ -92,18 +96,9 @@ class LibLlamaModel:
return list(self.token_ids[0:num])
def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a,b) in enumerate(zip(ids1, ids2)):
if a != b:
return i
if len(ids1) == len(ids2):
return -1
return min(len(ids1), len(ids2))
def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
tests = [
def generator_custom_text() -> Iterator[str]:
"""General tests"""
yield from [
"",
" ",
" ",
@ -146,7 +141,10 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
"333333333",
]
more_tests = [
def generator_custom_text_edge_cases() -> Iterator[str]:
"""Edge cases found while debugging"""
yield from [
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
'¼-a', # unicode_ranges_digit, 0x00BC
'½-a', # unicode_ranges_digit, 0x00BD
@ -157,18 +155,9 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
'<s>a' # TODO: Phi-3 fail
]
for text in tests + more_tests:
ids1 = model.tokenize(text, add_special=False, parse_special=False)
ids2 = tokenizer.encode(text, add_special_tokens=False)
logger.info(repr(text))
if ids1 != ids2:
logger.info(" TokenIDs: " + str(list(ids1)))
logger.info(" Expected: " + str(list(ids2)))
logger.info(" Index: %d" % find_first_mismatch(ids1, ids2))
raise Exception()
def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
def generator_random_chars(iterations = 100) -> Iterator[str]:
"""Brute force random text with simple characters"""
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
CHARS = list(set("""
@ -179,13 +168,9 @@ def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
"""))
logger.info("Bruteforce random chars encodings ...")
rand = random.Random()
for m in range(iterations):
logger.debug("%d/%d" % (m + 1, iterations))
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
@ -193,61 +178,41 @@ def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
word = rand.choices(CHARS, k=k)
space = rand.choice(WHITESPACES)
text.append("".join(word) + space)
text = "".join(text)
ids1 = model.tokenize(text, add_special=False, parse_special=False)
ids2 = tokenizer.encode(text, add_special_tokens=False)
assert(ids1 == ids2)
yield "".join(text)
def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
"""Brute force random text with vocab characters"""
logger.info("Building vocab char list ...")
vocab_ids = list(tokenizer.vocab.values())
vocab_text = tokenizer.decode(vocab_ids)
vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
vocab_chars = list(set(vocab_text))
del vocab_ids, vocab_text
logger.info("Bruteforce random text encodings ...")
rand = random.Random()
for m in range(iterations):
logger.debug("%d/%d" % (m + 1, iterations))
rand.seed(m)
text = rand.choices(vocab_chars, k=1024)
text = "".join(text)
ids1 = model.tokenize(text, add_special=False, parse_special=False)
ids2 = tokenizer.encode(text, add_special_tokens=False)
assert(ids1 == ids2)
yield "".join(text)
def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
"""Brute force random text from vocab tokens"""
logger.info("Building token list ...")
space_id = tokenizer.encode(" ")[0]
space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
vocab_ids = list(tokenizer.vocab.values())
vocab_ids = list(sorted(vocab_ids + vocab_ids))
for i in range(1, len(vocab_ids), 2):
vocab_ids[i] = space_id
vocab_tokens = tokenizer.decode(vocab_ids)
vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
vocab_tokens = vocab_tokens.split(" ")
del vocab_ids
logger.info("Checking single token encodings ...")
for token in vocab_tokens:
ids1 = model.tokenize(token, parse_special=True)
ids2 = tokenizer.encode(token)
assert(ids1 == ids2)
yield from vocab_tokens
logger.info("Bruteforce random text encodings ...")
rand = random.Random()
for m in range(iterations):
logger.debug("%d/%d" % (m + 1, iterations))
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
@ -256,24 +221,17 @@ def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenize
tokens = [t.strip(" \n\r\t") for t in tokens]
sep = rand.choice(" \n\r\t")
text.append("".join(tokens) + sep)
text = "".join(text)
ids1 = model.tokenize(text, add_special=False, parse_special=False)
ids2 = tokenizer.encode(text, add_special_tokens=False)
assert(ids1 == ids2)
yield "".join(text)
def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
def generator_random_bytes(iterations = 100) -> Iterator[str]:
"""Brute force random bytes"""
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
logger.info("Bruteforce random bytes encodings ...")
rand = random.Random()
for m in range(iterations):
logger.debug("%d/%d" % (m + 1, iterations))
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
@ -281,11 +239,36 @@ def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
word = [chr(r) for r in rand.randbytes(k) if r]
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
text = "".join(text)
yield "".join(text)
def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a,b) in enumerate(zip(ids1, ids2)):
if a != b:
return i
if len(ids1) == len(ids2):
return -1
return min(len(ids1), len(ids2))
t0 = time.perf_counter()
logger.info("%s: %s" % (generator.__name__, "ini"))
for text in generator:
ids1 = model.tokenize(text, add_special=False, parse_special=False)
ids2 = tokenizer.encode(text, add_special_tokens=False)
assert(ids1 == ids2)
if ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
assert (text2 in text)
logger.info(" Text: " + repr(text2))
logger.info(" TokenIDs: " + str(ids1))
logger.info(" Expected: " + str(ids2))
raise Exception()
t1 = time.perf_counter()
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
if __name__ == "__main__":
@ -302,10 +285,11 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
test_custom_texts(model, tokenizer)
test_random_chars(model, tokenizer, 10_000)
test_random_vocab_chars(model, tokenizer, 10_000)
test_random_vocab_tokens(model, tokenizer, 10_000)
# test_random_bytes(model, tokenizer, 10_000) # FAIL
test_compare_tokenizer(model, tokenizer, generator_custom_text())
test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
# test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
model.free()