Merge 883dc22d44 into f0d4b29edf
commit 710e2f9dc5
1 changed file with 87 additions and 70 deletions

@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 # Test libllama tokenizer == AutoTokenizer.
 # Brute force random words/text generation.
 #
@@ -11,20 +13,24 @@ from __future__ import annotations
 import time
 import logging
 import argparse
+import shutil
 import subprocess
 import random
 import unicodedata
 
 from pathlib import Path
-from typing import Any, Iterator, cast
+from typing import Any, Iterator, cast, Sequence
 from typing_extensions import Buffer
-
+#
+# External Imports
 import cffi
 from transformers import AutoTokenizer, PreTrainedTokenizer
+#
 
-
 logger = logging.getLogger("test-tokenizer-random")
 
+if shutil.which("gcc") is None:
+    raise EnvironmentError("GCC is not available on this system. Please install GCC or use preprocessed headers.")
 
 class LibLlama:
 
@@ -408,76 +414,84 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
-
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
-    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
-        if text1 == text2:  # equal to TokenizerGroundtruth?
-            return True
-        # equal to source text?
-        if tokenizer1.add_bos_token:  # remove BOS
-            if text2.startswith(tokenizer1.bos_token):
-                text2 = text2[len(tokenizer1.bos_token):]
-        if tokenizer1.add_eos_token:  # remove EOS
-            if text2.endswith(tokenizer1.eos_token):
-                text2 = text2[:-len(tokenizer1.eos_token)]
-        return text == text2
-
-    t_encode1 = 0
-    t_encode2 = 0
-    t_decode1 = 0
-    t_decode2 = 0
-    t_start = time.perf_counter()
-    encode_errors = 0
-    decode_errors = 0
-    MAX_ERRORS = 10
-
-    logger.info("%s: %s" % (generator.__qualname__, "ini"))
-    for text in generator:
-        # print(repr(text), text.encode())
-        # print(repr(text), hex(ord(text[0])), text.encode())
-        t0 = time.perf_counter()
-        ids1 = tokenizer1.encode(text)
-        t1 = time.perf_counter()
-        ids2 = tokenizer2.encode(text)
-        t2 = time.perf_counter()
-        text1 = tokenizer1.decode(ids1)
-        t3 = time.perf_counter()
-        text2 = tokenizer2.decode(ids1)
-        t4 = time.perf_counter()
-        t_encode1 += t1 - t0
-        t_encode2 += t2 - t1
-        t_decode1 += t3 - t2
-        t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.error(" Expected: " + str(ids1))
-            logger.error(" Result: " + str(ids2))
-            encode_errors += 1
-            logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
-            logger.error(f" {decode_errors=}")
-        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
-            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
-            # raise Exception()
-            break
-
-    t_total = time.perf_counter() - t_start
-    logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    try:
+        # def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
+        #     for i, (a, b) in enumerate(zip(ids1, ids2)):
+        #         if a != b:
+        #             return i
+        #     if len(ids1) == len(ids2):
+        #         return -1
+        #     return min(len(ids1), len(ids2))
+        # Rewritten to use zip() and next() instead of for loop
+        def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
+            index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
+            if index < 0 and len(ids1) != len(ids2):
+                index = min(len(ids1), len(ids2))
+            return index
+
+        def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+            if text1 == text2:  # equal to TokenizerGroundtruth?
+                return True
+            # equal to source text?
+            if tokenizer1.add_bos_token:  # remove BOS
+                if text2.startswith(tokenizer1.bos_token):
+                    text2 = text2[len(tokenizer1.bos_token):]
+            if tokenizer1.add_eos_token:  # remove EOS
+                if text2.endswith(tokenizer1.eos_token):
+                    text2 = text2[:-len(tokenizer1.eos_token)]
+            return text == text2
+
+        t_encode1 = 0
+        t_encode2 = 0
+        t_decode1 = 0
+        t_decode2 = 0
+        t_start = time.perf_counter()
+        encode_errors = 0
+        decode_errors = 0
+        MAX_ERRORS = 10
+
+        logger.info("%s: %s" % (generator.__qualname__, "ini"))
+        for text in generator:
+            # print(repr(text), text.encode())
+            # print(repr(text), hex(ord(text[0])), text.encode())
+            t0 = time.perf_counter()
+            ids1 = tokenizer1.encode(text)
+            t1 = time.perf_counter()
+            ids2 = tokenizer2.encode(text)
+            t2 = time.perf_counter()
+            text1 = tokenizer1.decode(ids1)
+            t3 = time.perf_counter()
+            text2 = tokenizer2.decode(ids1)
+            t4 = time.perf_counter()
+            t_encode1 += t1 - t0
+            t_encode2 += t2 - t1
+            t_decode1 += t3 - t2
+            t_decode2 += t4 - t3
+            if encode_errors < MAX_ERRORS and ids1 != ids2:
+                i = find_first_mismatch(ids1, ids2)
+                ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+                ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
+                logger.error(" Expected: " + str(ids1))
+                logger.error(" Result: " + str(ids2))
+                encode_errors += 1
+                logger.error(f" {encode_errors=}")
+            if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+                i = find_first_mismatch(text1, text2)
+                text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+                text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+                logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+                logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
+                decode_errors += 1
+                logger.error(f" {decode_errors=}")
+            if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+                logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
+                # raise Exception()
+                break
+
+        t_total = time.perf_counter() - t_start
+        logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    except Exception as e:
+        logger.exception(f"An error occurred during tokenizer comparison: {e}")
 
 
 def main(argv: list[str] | None = None):
@@ -485,6 +499,9 @@ def main(argv: list[str] | None = None):
     parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
+    parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
+    parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=tokenizers)
     args = parser.parse_args(argv)
 
     logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
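
For reference, a standalone sketch (not part of the diff) of the zip()/next() lookup that the rewritten find_first_mismatch relies on; the assert lines are illustrative checks added here, not taken from the change:

# Minimal sketch of the next()/zip() mismatch search used above.
from typing import Any, Sequence

def find_first_mismatch(ids1: Sequence[Any], ids2: Sequence[Any]) -> int:
    # -1 means the sequences are identical over their common length and equally long
    index = next((i for i, (a, b) in enumerate(zip(ids1, ids2)) if a != b), -1)
    if index < 0 and len(ids1) != len(ids2):
        index = min(len(ids1), len(ids2))  # shorter input is a strict prefix of the longer one
    return index

assert find_first_mismatch([1, 2, 3], [1, 2, 3]) == -1  # identical sequences
assert find_first_mismatch([1, 9, 3], [1, 2, 3]) == 1   # first differing position
assert find_first_mismatch([1, 2], [1, 2, 3]) == 2      # only the lengths differ
assert find_first_mismatch("abc", "abd") == 2           # also works on strings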
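
And a minimal, self-contained sketch of how the newly added command-line options parse, mirroring the argparse definitions in the last hunk; DEFAULT_TOKENIZERS and the sample argv values are placeholders, not from the diff:

# Illustrative parser with the options added by this change.
import argparse

DEFAULT_TOKENIZERS = ["llama-spm", "llama-bpe"]  # placeholder for the script's own list

parser = argparse.ArgumentParser()
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--max-errors", type=int, default=10, help="Maximum number of errors before stopping")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations for random generators")
parser.add_argument("--tokenizers", nargs="+", help="List of tokenizers to test", default=DEFAULT_TOKENIZERS)

# Sample invocation with placeholder paths; unspecified flags keep their defaults.
args = parser.parse_args(["vocab.gguf", "./tokenizer_dir", "--max-errors", "5", "--iterations", "200"])
print(args.max_errors, args.iterations, args.tokenizers)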