Minor + style

jaime-m-p 2024-05-14 23:13:02 +02:00
parent 1714e1a775
commit 9bc5d83502
2 changed files with 35 additions and 31 deletions

@@ -15,6 +15,7 @@ class CoodepointFlags (ctypes.Structure):
         ("is_control", ctypes.c_uint16, 1),  # regex: \p{C}
     ]
 assert (ctypes.sizeof(CoodepointFlags) == 2)
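For reference, the size assert holds because ctypes packs consecutive 1-bit fields declared on the same c_uint16 base type into a single 16-bit unit. A minimal standalone sketch (only is_control is taken from the diff; the other flag names are placeholders):

    import ctypes

    class FlagsSketch(ctypes.Structure):
        # up to 16 one-bit flags fit into a single c_uint16
        _fields_ = [
            ("is_whatever", ctypes.c_uint16, 1),  # placeholder flag
            ("is_other",    ctypes.c_uint16, 1),  # placeholder flag
            ("is_control",  ctypes.c_uint16, 1),  # regex: \p{C}
        ]

    assert ctypes.sizeof(FlagsSketch) == 2  # all bit fields share one 16-bit slot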

@@ -6,6 +6,7 @@
 # python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #
+import os
 import time
 import logging
 import argparse
@@ -15,7 +16,7 @@ import random
 from typing import Callable, Iterator
 import cffi
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer

 logger = logging.getLogger("test-tokenizer-random-bpe")
@@ -259,9 +260,6 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
-            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
-            #assert (text2 in text)
-            logger.info(" Text: " + repr(text2))
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
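The surviving logging reports a window of a few token ids around the first mismatch. A rough illustration of that windowing, assuming (as the name suggests) that find_first_mismatch returns the index of the first differing position:

    # hypothetical data; the slice mirrors the expressions in the diff above
    ids1 = [1, 5, 9, 42, 7, 8]
    ids2 = [1, 5, 9, 13, 7, 8]
    i = 3  # first index where ids1 and ids2 differ
    window1 = ids1[max(0, i - 2) : i + 2 + 1]  # [5, 9, 42, 7, 8]
    window2 = ids2[max(0, i - 2) : i + 2 + 1]  # [5, 9, 13, 7, 8]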
@@ -269,22 +267,23 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
     logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))


-if __name__ == "__main__":
+def main(argv: list[str] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-    args = parser.parse_args()
+    args = parser.parse_args(argv)

     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)

     def func_tokenize2(text: str):
         return tokenizer.encode(text, add_special_tokens=False)

+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+
     parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)

     def func_tokenize1(text: str):
         return model.tokenize(text, add_special=False, parse_special=parse_special)
@@ -298,3 +297,7 @@ if __name__ == "__main__":
     # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000))  # FAIL

     model.free()
+
+
+if __name__ == "__main__":
+    main()
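With the entry point factored into main(argv), parse_args(argv) still falls back to sys.argv[1:] when argv is None, so command-line usage is unchanged, while other scripts can drive the test directly. A rough sketch, reusing the example paths from the file's own usage comment:

    # hypothetical programmatic invocation of the refactored entry point
    main([
        "./models/ggml-vocab-llama-bpe.gguf",  # vocab_file
        "./models/tokenizers/llama-bpe",       # dir_tokenizer
        "--verbose",
    ])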