Refactor random tokenizer test
This commit is contained in:
parent
70ca1fe204
commit
77cbb79532
1 changed files with 65 additions and 81 deletions
|
@ -1,15 +1,19 @@
|
||||||
# tests with BPE tokenizer
|
# Test libllama tokenizer == AutoTokenizer.
|
||||||
|
# Brute force random tokens/text generation.
|
||||||
#
|
#
|
||||||
# sample usage:
|
# Sample usage:
|
||||||
#
|
#
|
||||||
# python3 tests/test-tokenizer-0-bpe.py ./models/ggml-vocab-llama-bpe.gguf ~/Data/huggingface/Meta-Llama-3-8B-Instruct/
|
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import time
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
import cffi
|
import cffi
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
@ -30,7 +34,7 @@ class LibLlama:
|
||||||
def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
|
def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
|
||||||
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
|
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
|
||||||
res = subprocess.run(cmd, stdout=subprocess.PIPE)
|
res = subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||||
assert(res.returncode == 0)
|
assert (res.returncode == 0)
|
||||||
source = res.stdout.decode()
|
source = res.stdout.decode()
|
||||||
ffi = cffi.FFI()
|
ffi = cffi.FFI()
|
||||||
if True: # workarounds for pycparser
|
if True: # workarounds for pycparser
|
||||||
|
@ -61,12 +65,12 @@ class LibLlamaModel:
|
||||||
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
|
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
|
||||||
self.lib = libllama.lib
|
self.lib = libllama.lib
|
||||||
self.ffi = libllama.ffi
|
self.ffi = libllama.ffi
|
||||||
if type(mparams) == dict:
|
if isinstance(mparams, dict):
|
||||||
mparams = libllama.model_default_params(**mparams)
|
mparams = libllama.model_default_params(**mparams)
|
||||||
self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
|
self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
|
||||||
if not self.model:
|
if not self.model:
|
||||||
raise RuntimeError("error: failed to load model '%s'" % path_model)
|
raise RuntimeError("error: failed to load model '%s'" % path_model)
|
||||||
if type(cparams) == dict:
|
if isinstance(cparams, dict):
|
||||||
cparams = libllama.context_default_params(**cparams)
|
cparams = libllama.context_default_params(**cparams)
|
||||||
self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
|
self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
|
||||||
if not self.ctx:
|
if not self.ctx:
|
||||||
|
@ -92,18 +96,9 @@ class LibLlamaModel:
|
||||||
return list(self.token_ids[0:num])
|
return list(self.token_ids[0:num])
|
||||||
|
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
def generator_custom_text() -> Iterator[str]:
|
||||||
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
"""General tests"""
|
||||||
if a != b:
|
yield from [
|
||||||
return i
|
|
||||||
if len(ids1) == len(ids2):
|
|
||||||
return -1
|
|
||||||
return min(len(ids1), len(ids2))
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
|
|
||||||
|
|
||||||
tests = [
|
|
||||||
"",
|
"",
|
||||||
" ",
|
" ",
|
||||||
" ",
|
" ",
|
||||||
|
@ -146,7 +141,10 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
|
||||||
"333333333",
|
"333333333",
|
||||||
]
|
]
|
||||||
|
|
||||||
more_tests = [
|
|
||||||
|
def generator_custom_text_edge_cases() -> Iterator[str]:
|
||||||
|
"""Edge cases found while debugging"""
|
||||||
|
yield from [
|
||||||
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
||||||
'¼-a', # unicode_ranges_digit, 0x00BC
|
'¼-a', # unicode_ranges_digit, 0x00BC
|
||||||
'½-a', # unicode_ranges_digit, 0x00BD
|
'½-a', # unicode_ranges_digit, 0x00BD
|
||||||
|
@ -157,18 +155,9 @@ def test_custom_texts(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase):
|
||||||
'<s>a' # TODO: Phi-3 fail
|
'<s>a' # TODO: Phi-3 fail
|
||||||
]
|
]
|
||||||
|
|
||||||
for text in tests + more_tests:
|
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
logger.info(repr(text))
|
|
||||||
if ids1 != ids2:
|
|
||||||
logger.info(" TokenIDs: " + str(list(ids1)))
|
|
||||||
logger.info(" Expected: " + str(list(ids2)))
|
|
||||||
logger.info(" Index: %d" % find_first_mismatch(ids1, ids2))
|
|
||||||
raise Exception()
|
|
||||||
|
|
||||||
|
def generator_random_chars(iterations = 100) -> Iterator[str]:
|
||||||
def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
|
"""Brute force random text with simple characters"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
CHARS = list(set("""
|
CHARS = list(set("""
|
||||||
|
@ -179,13 +168,9 @@ def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
|
||||||
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
|
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
logger.info("Bruteforce random chars encodings ...")
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
|
|
||||||
logger.debug("%d/%d" % (m + 1, iterations))
|
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
|
|
||||||
text = []
|
text = []
|
||||||
num_words = rand.randint(300, 400)
|
num_words = rand.randint(300, 400)
|
||||||
for i in range(num_words):
|
for i in range(num_words):
|
||||||
|
@ -193,61 +178,41 @@ def test_random_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
|
||||||
word = rand.choices(CHARS, k=k)
|
word = rand.choices(CHARS, k=k)
|
||||||
space = rand.choice(WHITESPACES)
|
space = rand.choice(WHITESPACES)
|
||||||
text.append("".join(word) + space)
|
text.append("".join(word) + space)
|
||||||
text = "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
assert(ids1 == ids2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_vocab_chars(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
|
def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
||||||
|
"""Brute force random text with vocab characters"""
|
||||||
|
|
||||||
logger.info("Building vocab char list ...")
|
|
||||||
vocab_ids = list(tokenizer.vocab.values())
|
vocab_ids = list(tokenizer.vocab.values())
|
||||||
vocab_text = tokenizer.decode(vocab_ids)
|
vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
||||||
vocab_chars = list(set(vocab_text))
|
vocab_chars = list(set(vocab_text))
|
||||||
del vocab_ids, vocab_text
|
del vocab_ids, vocab_text
|
||||||
|
|
||||||
logger.info("Bruteforce random text encodings ...")
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
|
|
||||||
logger.debug("%d/%d" % (m + 1, iterations))
|
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
|
|
||||||
text = rand.choices(vocab_chars, k=1024)
|
text = rand.choices(vocab_chars, k=1024)
|
||||||
text = "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
assert(ids1 == ids2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
|
def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
||||||
|
"""Brute force random text from vocab tokens"""
|
||||||
|
|
||||||
logger.info("Building token list ...")
|
space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
|
||||||
space_id = tokenizer.encode(" ")[0]
|
|
||||||
vocab_ids = list(tokenizer.vocab.values())
|
vocab_ids = list(tokenizer.vocab.values())
|
||||||
vocab_ids = list(sorted(vocab_ids + vocab_ids))
|
vocab_ids = list(sorted(vocab_ids + vocab_ids))
|
||||||
for i in range(1, len(vocab_ids), 2):
|
for i in range(1, len(vocab_ids), 2):
|
||||||
vocab_ids[i] = space_id
|
vocab_ids[i] = space_id
|
||||||
vocab_tokens = tokenizer.decode(vocab_ids)
|
vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
||||||
vocab_tokens = vocab_tokens.split(" ")
|
vocab_tokens = vocab_tokens.split(" ")
|
||||||
del vocab_ids
|
del vocab_ids
|
||||||
|
|
||||||
logger.info("Checking single token encodings ...")
|
yield from vocab_tokens
|
||||||
for token in vocab_tokens:
|
|
||||||
ids1 = model.tokenize(token, parse_special=True)
|
|
||||||
ids2 = tokenizer.encode(token)
|
|
||||||
assert(ids1 == ids2)
|
|
||||||
|
|
||||||
logger.info("Bruteforce random text encodings ...")
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
|
|
||||||
logger.debug("%d/%d" % (m + 1, iterations))
|
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
|
|
||||||
text = []
|
text = []
|
||||||
num_words = rand.randint(300, 400)
|
num_words = rand.randint(300, 400)
|
||||||
for i in range(num_words):
|
for i in range(num_words):
|
||||||
|
@ -256,24 +221,17 @@ def test_random_vocab_tokens(model: LibLlamaModel, tokenizer: PreTrainedTokenize
|
||||||
tokens = [t.strip(" \n\r\t") for t in tokens]
|
tokens = [t.strip(" \n\r\t") for t in tokens]
|
||||||
sep = rand.choice(" \n\r\t")
|
sep = rand.choice(" \n\r\t")
|
||||||
text.append("".join(tokens) + sep)
|
text.append("".join(tokens) + sep)
|
||||||
text = "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
assert(ids1 == ids2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, iterations = 100):
|
def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
||||||
|
"""Brute force random bytes"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
|
|
||||||
logger.info("Bruteforce random bytes encodings ...")
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
|
|
||||||
logger.debug("%d/%d" % (m + 1, iterations))
|
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
|
|
||||||
text = []
|
text = []
|
||||||
num_words = rand.randint(300, 400)
|
num_words = rand.randint(300, 400)
|
||||||
for i in range(num_words):
|
for i in range(num_words):
|
||||||
|
@ -281,11 +239,36 @@ def test_random_bytes(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase,
|
||||||
word = [chr(r) for r in rand.randbytes(k) if r]
|
word = [chr(r) for r in rand.randbytes(k) if r]
|
||||||
word.append(rand.choice(WHITESPACES))
|
word.append(rand.choice(WHITESPACES))
|
||||||
text.append("".join(word))
|
text.append("".join(word))
|
||||||
text = "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
|
||||||
|
|
||||||
|
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||||
|
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
||||||
|
if a != b:
|
||||||
|
return i
|
||||||
|
if len(ids1) == len(ids2):
|
||||||
|
return -1
|
||||||
|
return min(len(ids1), len(ids2))
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||||
|
for text in generator:
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
||||||
assert(ids1 == ids2)
|
if ids1 != ids2:
|
||||||
|
i = find_first_mismatch(ids1, ids2)
|
||||||
|
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
||||||
|
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
||||||
|
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
|
||||||
|
assert (text2 in text)
|
||||||
|
logger.info(" Text: " + repr(text2))
|
||||||
|
logger.info(" TokenIDs: " + str(ids1))
|
||||||
|
logger.info(" Expected: " + str(ids2))
|
||||||
|
raise Exception()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -302,10 +285,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
||||||
|
|
||||||
test_custom_texts(model, tokenizer)
|
test_compare_tokenizer(model, tokenizer, generator_custom_text())
|
||||||
test_random_chars(model, tokenizer, 10_000)
|
test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
|
||||||
test_random_vocab_chars(model, tokenizer, 10_000)
|
test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
|
||||||
test_random_vocab_tokens(model, tokenizer, 10_000)
|
test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
|
||||||
# test_random_bytes(model, tokenizer, 10_000) # FAIL
|
test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
|
||||||
|
# test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
|
||||||
|
|
||||||
model.free()
|
model.free()
|
Loading…
Add table
Add a link
Reference in a new issue