Minor + style
This commit is contained in:
parent
1714e1a775
commit
9bc5d83502
2 changed files with 35 additions and 31 deletions
|
@ -15,7 +15,8 @@ class CoodepointFlags (ctypes.Structure):
|
||||||
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
|
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
|
||||||
]
|
]
|
||||||
|
|
||||||
assert(ctypes.sizeof(CoodepointFlags) == 2)
|
|
||||||
|
assert (ctypes.sizeof(CoodepointFlags) == 2)
|
||||||
|
|
||||||
|
|
||||||
MAX_CODEPOINTS = 0x110000
|
MAX_CODEPOINTS = 0x110000
|
||||||
|
@ -49,7 +50,7 @@ for codepoint in range(MAX_CODEPOINTS):
|
||||||
flags.is_symbol = bool(regex_symbol.match(char))
|
flags.is_symbol = bool(regex_symbol.match(char))
|
||||||
flags.is_control = bool(regex_control.match(char))
|
flags.is_control = bool(regex_control.match(char))
|
||||||
flags.is_undefined = bytes(flags)[0] == 0
|
flags.is_undefined = bytes(flags)[0] == 0
|
||||||
assert(not flags.is_undefined)
|
assert (not flags.is_undefined)
|
||||||
|
|
||||||
# whitespaces
|
# whitespaces
|
||||||
if bool(regex_whitespace.match(char)):
|
if bool(regex_whitespace.match(char)):
|
||||||
|
@ -72,7 +73,7 @@ for codepoint in range(MAX_CODEPOINTS):
|
||||||
|
|
||||||
|
|
||||||
# group ranges with same flags
|
# group ranges with same flags
|
||||||
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
||||||
for codepoint, flags in enumerate(codepoint_flags):
|
for codepoint, flags in enumerate(codepoint_flags):
|
||||||
if bytes(flags) != bytes(ranges_flags[-1][1]):
|
if bytes(flags) != bytes(ranges_flags[-1][1]):
|
||||||
ranges_flags.append((codepoint, flags))
|
ranges_flags.append((codepoint, flags))
|
||||||
|
@ -80,7 +81,7 @@ ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
|
||||||
|
|
||||||
|
|
||||||
# group ranges with same nfd
|
# group ranges with same nfd
|
||||||
ranges_nfd = [(0, 0, 0)] # start, last, nfd
|
ranges_nfd = [(0, 0, 0)] # start, last, nfd
|
||||||
for codepoint, norm in table_nfd:
|
for codepoint, norm in table_nfd:
|
||||||
start = ranges_nfd[-1][0]
|
start = ranges_nfd[-1][0]
|
||||||
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
|
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -15,7 +16,7 @@ import random
|
||||||
from typing import Callable, Iterator
|
from typing import Callable, Iterator
|
||||||
|
|
||||||
import cffi
|
import cffi
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger("test-tokenizer-random-bpe")
|
logger = logging.getLogger("test-tokenizer-random-bpe")
|
||||||
|
|
||||||
|
@ -145,16 +146,16 @@ def generator_custom_text() -> Iterator[str]:
|
||||||
def generator_custom_text_edge_cases() -> Iterator[str]:
|
def generator_custom_text_edge_cases() -> Iterator[str]:
|
||||||
"""Edge cases found while debugging"""
|
"""Edge cases found while debugging"""
|
||||||
yield from [
|
yield from [
|
||||||
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
||||||
'¼-a', # unicode_ranges_digit, 0x00BC
|
'¼-a', # unicode_ranges_digit, 0x00BC
|
||||||
'½-a', # unicode_ranges_digit, 0x00BD
|
'½-a', # unicode_ranges_digit, 0x00BD
|
||||||
'¾-a', # unicode_ranges_digit, 0x00BE
|
'¾-a', # unicode_ranges_digit, 0x00BE
|
||||||
'a 〇b', # unicode_ranges_digit, 0x3007
|
'a 〇b', # unicode_ranges_digit, 0x3007
|
||||||
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
||||||
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
||||||
'Cửa Việt', # llama-3, ignore_merges = true
|
'Cửa Việt', # llama-3, ignore_merges = true
|
||||||
'<s>a', # TODO: Phi-3 fail
|
'<s>a', # TODO: Phi-3 fail
|
||||||
'a\na', # TODO: Bert fail
|
'a\na', # TODO: Bert fail
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,7 +164,7 @@ def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
|
||||||
yield from vocab
|
yield from vocab
|
||||||
|
|
||||||
|
|
||||||
def generator_random_chars(iterations = 100) -> Iterator[str]:
|
def generator_random_chars(iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text with simple characters"""
|
"""Brute force random text with simple characters"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
|
@ -188,7 +189,7 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator[str]:
|
def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text with vocab characters"""
|
"""Brute force random text with vocab characters"""
|
||||||
|
|
||||||
vocab_chars = set()
|
vocab_chars = set()
|
||||||
|
@ -203,7 +204,7 @@ def generator_random_vocab_chars(vocab: list[str], iterations = 100) -> Iterator
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator[str]:
|
def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text from vocab words"""
|
"""Brute force random text from vocab words"""
|
||||||
|
|
||||||
vocab = [w.strip() for w in vocab]
|
vocab = [w.strip() for w in vocab]
|
||||||
|
@ -222,7 +223,7 @@ def generator_random_vocab_words(vocab: list[str], iterations = 100) -> Iterator
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
def generator_random_bytes(iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random bytes"""
|
"""Brute force random bytes"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
|
@ -243,7 +244,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
||||||
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||||
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||||
if a != b:
|
if a != b:
|
||||||
return i
|
return i
|
||||||
if len(ids1) == len(ids2):
|
if len(ids1) == len(ids2):
|
||||||
|
@ -259,9 +260,6 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
|
||||||
i = find_first_mismatch(ids1, ids2)
|
i = find_first_mismatch(ids1, ids2)
|
||||||
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
||||||
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
||||||
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
|
|
||||||
#assert (text2 in text)
|
|
||||||
logger.info(" Text: " + repr(text2))
|
|
||||||
logger.info(" TokenIDs: " + str(ids1))
|
logger.info(" TokenIDs: " + str(ids1))
|
||||||
logger.info(" Expected: " + str(ids2))
|
logger.info(" Expected: " + str(ids2))
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
@ -269,23 +267,24 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
|
||||||
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main(argv: list[str] = None):
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
||||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||||
|
|
||||||
|
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
||||||
def func_tokenize2(text:str):
|
|
||||||
|
def func_tokenize2(text: str):
|
||||||
return tokenizer.encode(text, add_special_tokens=False)
|
return tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
|
||||||
parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
|
parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
|
||||||
def func_tokenize1(text:str):
|
|
||||||
|
def func_tokenize1(text: str):
|
||||||
return model.tokenize(text, add_special=False, parse_special=parse_special)
|
return model.tokenize(text, add_special=False, parse_special=parse_special)
|
||||||
|
|
||||||
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
||||||
|
@ -298,3 +297,7 @@ if __name__ == "__main__":
|
||||||
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
|
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
|
||||||
|
|
||||||
model.free()
|
model.free()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue