Robert 2025-01-29 10:21:11 -05:00 committed by GitHub
commit 710e2f9dc5


@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 # Test libllama tokenizer == AutoTokenizer.
 # Brute force random words/text generation.
 #
@@ -11,20 +13,24 @@ from __future__ import annotations
 import time
 import logging
 import argparse
+import shutil
 import subprocess
 import random
 import unicodedata
 from pathlib import Path
-from typing import Any, Iterator, cast
+from typing import Any, Iterator, cast, Sequence
 from typing_extensions import Buffer
+#
+# External Imports
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizer
+#
 logger = logging.getLogger("test-tokenizer-random")
+if shutil.which("gcc") is None:
+    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")

 class LibLlama:
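For reference only (not part of this commit), a standalone sketch of the guard added above: shutil.which() returns the resolved executable path or None, so the same check can also report which compiler was found. The print line and path are illustrative.

import shutil

gcc_path = shutil.which("gcc")
if gcc_path is None:
    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")
print(f"found gcc at {gcc_path}")  # e.g. /usr/bin/gcc; the path varies by system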
@@ -408,76 +414,84 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
-    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
-        if text1 == text2:  # equal to TokenizerGroundtruth?
-            return True
-        # equal to source text?
-        if tokenizer1.add_bos_token:  # remove BOS
-            if text2.startswith(tokenizer1.bos_token):
-                text2 = text2[len(tokenizer1.bos_token):]
-        if tokenizer1.add_eos_token:  # remove EOS
-            if text2.endswith(tokenizer1.eos_token):
-                text2 = text2[:-len(tokenizer1.eos_token)]
-        return text == text2
-
-    t_encode1 = 0
-    t_encode2 = 0
-    t_decode1 = 0
-    t_decode2 = 0
-    t_start = time.perf_counter()
-    encode_errors = 0
-    decode_errors = 0
-    MAX_ERRORS = 10
-
-    logger.info("%s: %s" % (generator.__qualname__, "ini"))
-    for text in generator:
-        # print(repr(text), text.encode())
-        # print(repr(text), hex(ord(text[0])), text.encode())
-        t0 = time.perf_counter()
-        ids1 = tokenizer1.encode(text)
-        t1 = time.perf_counter()
-        ids2 = tokenizer2.encode(text)
-        t2 = time.perf_counter()
-        text1 = tokenizer1.decode(ids1)
-        t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids1)
-        t4 = time.perf_counter()
-        t_encode1 += t1 - t0
-        t_encode2 += t2 - t1
-        t_decode1 += t3 - t2
-        t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.error(" Expected: " + str(ids1))
-            logger.error(" Result: " + str(ids2))
-            encode_errors += 1
-            logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
-            logger.error(f" {decode_errors=}")
-        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
-            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
-            # raise Exception()
-            break
-
-    t_total = time.perf_counter() - t_start
-    logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    try:
+        # def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        #     for i, (a, b) in enumerate(zip(ids1, ids2)):
+        #         if a != b:
+        #             return i
+        #     if len(ids1) == len(ids2):
+        #         return -1
+        #     return min(len(ids1), len(ids2))
+
+        # Rewritten to use zip() and next() instead of for loop
+        def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
+            index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
+            if index < 0 and len(ids1) != len(ids2):
+                index = min(len(ids1), len(ids2))
+            return index
+
+        def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+            if text1 == text2:  # equal to TokenizerGroundtruth?
+                return True
+            # equal to source text?
+            if tokenizer1.add_bos_token:  # remove BOS
+                if text2.startswith(tokenizer1.bos_token):
+                    text2 = text2[len(tokenizer1.bos_token):]
+            if tokenizer1.add_eos_token:  # remove EOS
+                if text2.endswith(tokenizer1.eos_token):
+                    text2 = text2[:-len(tokenizer1.eos_token)]
+            return text == text2
+
+        t_encode1 = 0
+        t_encode2 = 0
+        t_decode1 = 0
+        t_decode2 = 0
+        t_start = time.perf_counter()
+        encode_errors = 0
+        decode_errors = 0
+        MAX_ERRORS = 10
+
+        logger.info("%s: %s" % (generator.__qualname__, "ini"))
+        for text in generator:
+            # print(repr(text), text.encode())
+            # print(repr(text), hex(ord(text[0])), text.encode())
+            t0 = time.perf_counter()
+            ids1 = tokenizer1.encode(text)
+            t1 = time.perf_counter()
+            ids2 = tokenizer2.encode(text)
+            t2 = time.perf_counter()
+            text1 = tokenizer1.decode(ids1)
+            t3 = time.perf_counter()
+            text2 = tokenizer2.decode(ids1)
+            t4 = time.perf_counter()
+            t_encode1 += t1 - t0
+            t_encode2 += t2 - t1
+            t_decode1 += t3 - t2
+            t_decode2 += t4 - t3
+            if encode_errors < MAX_ERRORS and ids1 != ids2:
+                i = find_first_mismatch(ids1, ids2)
+                ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+                ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+                logger.error(" Expected: " + str(ids1))
+                logger.error(" Result: " + str(ids2))
+                encode_errors += 1
+                logger.error(f" {encode_errors=}")
+            if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+                i = find_first_mismatch(text1, text2)
+                text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+                text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+                logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+                logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
+                decode_errors += 1
+                logger.error(f" {decode_errors=}")
+            if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+                logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
+                # raise Exception()
+                break
+
+        t_total = time.perf_counter() - t_start
+        logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    except Exception as e:
+        logger.exception(f"An error occurred during tokenizer comparison: {e}")
 def main(argv: list[str] | None = None):
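For illustration only (not part of the diff): the rewritten find_first_mismatch helper, lifted out of compare_tokenizers into a module-level function, behaves like this on toy inputs. The helper body is copied verbatim; the assertions are made-up examples.

from typing import Any, Sequence

def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
    # Same logic as the nested helper in compare_tokenizers, shown standalone for the demo.
    index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
    if index < 0 and len(ids1) != len(ids2):
        index = min(len(ids1), len(ids2))
    return index

assert find_first_mismatch([1, 2, 3], [1, 9, 3]) == 1   # first differing position
assert find_first_mismatch([1, 2, 3], [1, 2, 3]) == -1  # identical sequences
assert find_first_mismatch([1, 2], [1, 2, 3]) == 2      # one sequence is a strict prefix of the other
assert find_first_mismatch("abc", "abd") == 2           # also works on strings, as used for detokenized text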
@@ -485,6 +499,9 @@ def main(argv: list[str] | None = None):
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file") parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file") parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity") parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=tokenizers)
args = parser.parse_args(argv) args = parser.parse_args(argv)
logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO) logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
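For illustration only (not part of the commit): how the new command-line options parse. The parser is rebuilt standalone here, and the --tokenizers default list and file paths are invented; in the script the default comes from a tokenizers list defined elsewhere.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("vocab_file", type=str)
parser.add_argument("dir_tokenizer", type=str)
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--max-errors", type=int, default=10)
parser.add_argument("--iterations", type=int, default=100)
parser.add_argument("--tokenizers", nargs="+", default=["llama-spm", "llama-bpe"])  # hypothetical default list

# argparse maps "--max-errors" to args.max_errors, so the flags read naturally in code.
args = parser.parse_args(["vocab.gguf", "./tokenizer-dir", "--max-errors", "5", "--tokenizers", "llama-bpe"])
print(args.max_errors, args.iterations, args.tokenizers)  # -> 5 100 ['llama-bpe']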