commit 710e2f9dc5
Author: Robert
Date:   2025-01-29 10:21:11 -05:00


@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 # Test libllama tokenizer == AutoTokenizer.
 # Brute force random words/text generation.
 #
@@ -11,20 +13,24 @@ from __future__ import annotations
 import time
 import logging
 import argparse
+import shutil
 import subprocess
 import random
 import unicodedata
 from pathlib import Path
-from typing import Any, Iterator, cast
+from typing import Any, Iterator, cast, Sequence
 from typing_extensions import Buffer
+#
+# External Imports
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizer
+#
 logger = logging.getLogger("test-tokenizer-random")
+if shutil.which("gcc") is None:
+    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")
 class LibLlama:
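
Note: the import-time guard added above uses shutil.which to make the gcc dependency explicit; the error message suggests gcc is needed to preprocess the libllama headers for cffi, though that wiring is outside this hunk. A minimal sketch of the same kind of pre-check, with the tool name and message as assumptions rather than exact behavior of this script:

    import shutil

    def require_tool(name: str = "gcc") -> str:
        """Return the full path to an executable on PATH, or raise if it is missing."""
        path = shutil.which(name)  # returns None when the executable cannot be found
        if path is None:
            raise EnvironmentError(f"{name} is not available on this system.")
        return path
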
@@ -408,14 +414,20 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
+    try:
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
+        # def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        #     for i, (a, b) in enumerate(zip(ids1, ids2)):
+        #         if a != b:
+        #             return i
+        #     if len(ids1) == len(ids2):
+        #         return -1
+        #     return min(len(ids1), len(ids2))
+        # Rewritten to use zip() and next() instead of for loop
+        def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
+            index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
+            if index < 0 and len(ids1) != len(ids2):
+                index = min(len(ids1), len(ids2))
+            return index

     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
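
Note: the rewritten find_first_mismatch keeps the old helper's contract: the index of the first differing element, -1 when the sequences are identical, and the length of the shorter sequence when one is a strict prefix of the other. A standalone check with made-up token ids (not part of the commit):

    from typing import Any, Sequence

    def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
        # first index where the sequences disagree, or -1 if no compared pair differs
        index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
        if index < 0 and len(ids1) != len(ids2):
            index = min(len(ids1), len(ids2))  # one sequence is a prefix of the other
        return index

    assert find_first_mismatch([1, 2, 3], [1, 2, 3]) == -1  # identical
    assert find_first_mismatch([1, 2, 3], [1, 9, 3]) == 1   # first difference at index 1
    assert find_first_mismatch([1, 2], [1, 2, 3]) == 2      # prefix: length of the shorter sequence
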
@@ -478,6 +490,8 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     t_total = time.perf_counter() - t_start
     logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    except Exception as e:
+        logger.exception(f"An error occurred during tokenizer comparison: {e}")

 def main(argv: list[str] | None = None):
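
Note: logger.exception logs at ERROR level and appends the traceback of the exception currently being handled, so a failure inside compare_tokenizers is recorded with full context instead of propagating uncaught. A self-contained illustration of that call, unrelated to the script itself:

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("demo")

    try:
        1 / 0
    except Exception as e:
        # .exception() implies ERROR level and includes the active traceback
        log.exception(f"An error occurred during tokenizer comparison: {e}")
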
@@ -485,6 +499,9 @@ def main(argv: list[str] | None = None):
     parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
+    parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
+    parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=tokenizers)
     args = parser.parse_args(argv)

     logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
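
Note: assuming the three new options are actually consumed later in main (not shown in this hunk), a run could be driven programmatically through main(argv); the paths and tokenizer names below are placeholders:

    # hypothetical invocation of the script's main(); arguments mirror the parser above
    main([
        "path/to/ggml-vocab.gguf",   # vocab_file
        "path/to/tokenizer_dir",     # dir_tokenizer
        "--verbose",
        "--max-errors", "5",
        "--iterations", "200",
        "--tokenizers", "llama-bpe", "phi-3",  # placeholder tokenizer names
    ])
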