diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index c6cdcb554..04a5d302e 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 # Test libllama tokenizer == AutoTokenizer.
 # Brute force random words/text generation.
 #
@@ -11,20 +13,24 @@ from __future__ import annotations
 import time
 import logging
 import argparse
+import shutil
 import subprocess
 import random
 import unicodedata
-
 from pathlib import Path
-from typing import Any, Iterator, cast
+from typing import Any, Iterator, cast, Sequence
 from typing_extensions import Buffer
-
+#
+# External Imports
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizer
-
+#
 
 logger = logging.getLogger("test-tokenizer-random")
 
+if shutil.which("gcc") is None:
+    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")
+
 
 class LibLlama:
 
@@ -408,76 +414,84 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
+    try:
+        # def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        #     for i, (a, b) in enumerate(zip(ids1, ids2)):
+        #         if a != b:
+        #             return i
+        #     if len(ids1) == len(ids2):
+        #         return -1
+        #     return min(len(ids1), len(ids2))
+        # Rewritten to use zip() and next() instead of for loop
+        def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
+            index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
+            if index < 0 and len(ids1) != len(ids2):
+                index = min(len(ids1), len(ids2))
+            return index
 
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
+        def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+            if text1 == text2:  # equal to TokenizerGroundtruth?
+                return True
+            # equal to source text?
+            if tokenizer1.add_bos_token:  # remove BOS
+                if text2.startswith(tokenizer1.bos_token):
+                    text2 = text2[len(tokenizer1.bos_token):]
+            if tokenizer1.add_eos_token:  # remove EOS
+                if text2.endswith(tokenizer1.eos_token):
+                    text2 = text2[:-len(tokenizer1.eos_token)]
+            return text == text2
 
-    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
-        if text1 == text2:  # equal to TokenizerGroundtruth?
-            return True
-        # equal to source text?
-        if tokenizer1.add_bos_token:  # remove BOS
-            if text2.startswith(tokenizer1.bos_token):
-                text2 = text2[len(tokenizer1.bos_token):]
-        if tokenizer1.add_eos_token:  # remove EOS
-            if text2.endswith(tokenizer1.eos_token):
-                text2 = text2[:-len(tokenizer1.eos_token)]
-        return text == text2
+        t_encode1 = 0
+        t_encode2 = 0
+        t_decode1 = 0
+        t_decode2 = 0
+        t_start = time.perf_counter()
+        encode_errors = 0
+        decode_errors = 0
+        MAX_ERRORS = 10
 
-    t_encode1 = 0
-    t_encode2 = 0
-    t_decode1 = 0
-    t_decode2 = 0
-    t_start = time.perf_counter()
-    encode_errors = 0
-    decode_errors = 0
-    MAX_ERRORS = 10
+        logger.info("%s: %s" % (generator.__qualname__, "ini"))
+        for text in generator:
+            # print(repr(text), text.encode())
+            # print(repr(text), hex(ord(text[0])), text.encode())
+            t0 = time.perf_counter()
+            ids1 = tokenizer1.encode(text)
+            t1 = time.perf_counter()
+            ids2 = tokenizer2.encode(text)
+            t2 = time.perf_counter()
+            text1 = tokenizer1.decode(ids1)
+            t3 = time.perf_counter()
+            text2 = tokenizer2.decode(ids1)
+            t4 = time.perf_counter()
+            t_encode1 += t1 - t0
+            t_encode2 += t2 - t1
+            t_decode1 += t3 - t2
+            t_decode2 += t4 - t3
+            if encode_errors < MAX_ERRORS and ids1 != ids2:
+                i = find_first_mismatch(ids1, ids2)
+                ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+                ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+                logger.error(" Expected: " + str(ids1))
+                logger.error("   Result: " + str(ids2))
+                encode_errors += 1
+                logger.error(f" {encode_errors=}")
+            if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+                i = find_first_mismatch(text1, text2)
+                text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+                text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+                logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+                logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
+                decode_errors += 1
+                logger.error(f" {decode_errors=}")
+            if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+                logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
+                # raise Exception()
+                break
 
-    logger.info("%s: %s" % (generator.__qualname__, "ini"))
-    for text in generator:
-        # print(repr(text), text.encode())
-        # print(repr(text), hex(ord(text[0])), text.encode())
-        t0 = time.perf_counter()
-        ids1 = tokenizer1.encode(text)
-        t1 = time.perf_counter()
-        ids2 = tokenizer2.encode(text)
-        t2 = time.perf_counter()
-        text1 = tokenizer1.decode(ids1)
-        t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids1)
-        t4 = time.perf_counter()
-        t_encode1 += t1 - t0
-        t_encode2 += t2 - t1
-        t_decode1 += t3 - t2
-        t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.error(" Expected: " + str(ids1))
-            logger.error("   Result: " + str(ids2))
-            encode_errors += 1
-            logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
-            logger.error(f" {decode_errors=}")
-        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
-            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
-            # raise Exception()
-            break
-
-    t_total = time.perf_counter() - t_start
-    logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+        t_total = time.perf_counter() - t_start
+        logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    except Exception as e:
+        logger.exception(f"An error occurred during tokenizer comparison: {e}")
 
 
 def main(argv: list[str] | None = None):
@@ -485,6 +499,9 @@ def main(argv: list[str] | None = None):
     parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
+    parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
+    parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=tokenizers)
     args = parser.parse_args(argv)
 
     logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
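For reference, a minimal standalone sketch of the zip()/next() first-mismatch idiom that the rewritten find_first_mismatch above relies on. The helper name and the sample inputs here are illustrative only and are not part of the patch:

from typing import Any, Sequence

def first_mismatch(a: Sequence[Any], b: Sequence[Any]) -> int:
    # Index of the first differing element; -1 means the sequences are equal.
    index = next((i for i, (x, y) in enumerate(zip(a, b)) if x != y), -1)
    if index < 0 and len(a) != len(b):
        index = min(len(a), len(b))  # no element differs, but one sequence is longer
    return index

assert first_mismatch([1, 2, 3], [1, 2, 3]) == -1   # identical
assert first_mismatch([1, 2, 3], [1, 9, 3]) == 1    # first differing element
assert first_mismatch("abc", "abcd") == 3           # shorter sequence is a prefix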