py : type-check all Python scripts with Pyright

Francis Couture-Harpin 2024-07-06 11:36:28 -04:00
parent 87e25a1d1b
commit e29fd9634c
35 changed files with 264 additions and 136 deletions

tests/test-tokenizer-random.py

@@ -6,6 +6,8 @@
 # python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #
+from __future__ import annotations
 import time
 import logging
 import argparse
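The `from __future__ import annotations` added here makes every annotation in the file a string that is never evaluated at runtime (PEP 563), which is what lets the `X | None` union syntax used throughout this commit parse even on Python versions older than 3.10. A minimal sketch, separate from the diff, with a hypothetical function name:

from __future__ import annotations

def find_vocab(path: str | None = None) -> str:
    # Without the future import, evaluating `str | None` at definition time
    # raises TypeError on Python 3.9 and older; with it, the annotation stays
    # a plain string that only a type checker like Pyright interprets.
    return path or "./models/ggml-vocab-llama-bpe.gguf"

print(find_vocab())  # ./models/ggml-vocab-llama-bpe.gguf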
@@ -13,7 +15,8 @@ import subprocess
 import random
 import unicodedata
-from typing import Iterator
+from typing import Any, Iterator, cast
+from typing_extensions import Buffer
 import cffi
 from transformers import AutoTokenizer
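`Buffer` is imported from `typing_extensions` because the ABC only reached the standard library as `collections.abc.Buffer` in Python 3.12 (PEP 688); the backport keeps older interpreters working. A small sketch of what the annotation accepts (`first_byte` is hypothetical):

from typing_extensions import Buffer  # collections.abc.Buffer on 3.12+

def first_byte(data: Buffer) -> int:
    # Any object implementing the C buffer protocol type-checks here:
    # bytes, bytearray, memoryview, array.array, cffi buffers, ...
    return memoryview(data)[0]

print(first_byte(b"\x7fELF"))  # 127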
@@ -28,14 +31,14 @@ class LibLlama:
     DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
     DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
-    def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
+    def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
         self.lib.llama_backend_init()
-    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
+    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
         cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
         cmd += ["-I" + path for path in path_includes] + [path_llama_h]
         res = subprocess.run(cmd, stdout=subprocess.PIPE)
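Both signature changes in this hunk address the same Pyright diagnostic: a parameter annotated plainly as `str` cannot default to `None`, because Pyright does not apply the legacy implicit-Optional rule. The shape of the fix, reduced to a hypothetical standalone function:

def load_header(path: str | None = None) -> str:
    # After the `or`, Pyright narrows `path` back to plain `str`, so the
    # rest of the body can treat it as non-optional, as __init__ does above.
    path = path or "./include/llama.h"
    return path

print(load_header())       # ./include/llama.h
print(load_header("a.h"))  # a.h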
@@ -68,7 +71,7 @@ class LibLlama:
 class LibLlamaModel:
     def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
-        self.lib = libllama.lib
+        self.lib: Any = libllama.lib
         self.ffi = libllama.ffi
         if isinstance(mparams, dict):
             mparams = libllama.model_default_params(**mparams)
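`libllama.lib` is a cffi library object whose functions exist only because `ffi.cdef()` registered them at runtime, so Pyright cannot know that attributes like `llama_tokenize` are valid; annotating the handle as `Any` opts that one attribute out of checking. A self-contained illustration of the same idea (Unix-only, since it pulls `abs` from the already-loaded C runtime):

from typing import Any
import cffi

ffi = cffi.FFI()
ffi.cdef("int abs(int x);")
# dlopen(None) exposes symbols of the running process; the returned object
# resolves names dynamically, hence the `Any` annotation, as with self.lib.
lib: Any = ffi.dlopen(None)
print(lib.abs(-5))  # 5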
@@ -94,11 +97,11 @@ class LibLlamaModel:
         self.lib = None
     def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
-        text = text.encode("utf-8")
-        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+        encoded_text: bytes = text.encode("utf-8")
+        num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
         while num < 0 and len(self.token_ids) < (16 << 20):
             self.token_ids = self.ffi.new("llama_token[]", -2 * num)
-            num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+            num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
         return list(self.token_ids[0:num])
     def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
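Re-binding `text` to the result of `.encode()` would assign a `bytes` value to a variable declared `str`, which Pyright rejects; the hunk introduces a separate, explicitly typed name instead. The fix in isolation, with a stub body standing in for the real `llama_tokenize` call:

def tokenize(text: str) -> list[int]:
    # text = text.encode("utf-8")  # rejected: `bytes` is not assignable to `str`
    encoded_text: bytes = text.encode("utf-8")
    return list(encoded_text)  # placeholder: byte values as fake token ids

print(tokenize("hi"))  # [104, 105]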
@@ -110,7 +113,7 @@ class LibLlamaModel:
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
+        return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
 class Tokenizer:
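`ffi.buffer()` returns an object that supports the buffer protocol at runtime, but cffi does not advertise that statically, so passing it straight to `str(..., encoding=...)` fails type-checking; `cast()` asserts the `Buffer` type for Pyright and does nothing at runtime. The mechanics, with a `bytearray` standing in for the cffi buffer:

from typing import cast
from typing_extensions import Buffer

raw = bytearray(b"caf\xc3\xa9\xff")
# cast() returns its argument unchanged; only the static type is asserted.
decoded = str(cast(Buffer, raw), encoding="utf-8", errors="replace")
print(decoded)  # 'café' plus U+FFFD for the stray 0xff byte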
@@ -152,7 +155,7 @@ class TokenizerGroundtruth (Tokenizer):
 class TokenizerLlamaCpp (Tokenizer):
-    libllama: LibLlama = None
+    libllama: LibLlama | None = None
     def __init__(self, vocab_file: str):
         if not self.libllama:
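The `libllama` class attribute genuinely starts out unset and is created lazily by the first instance, so its annotation must admit `None`. The same lazy-singleton shape as a hypothetical standalone example:

from __future__ import annotations

class Backend:
    def __init__(self) -> None:
        print("loaded once")

class Wrapper:
    # Shared, lazily created handle; `| None` records the initial unset state.
    backend: Backend | None = None

    def __init__(self) -> None:
        if not Wrapper.backend:
            Wrapper.backend = Backend()

Wrapper(); Wrapper()  # prints "loaded once" a single time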
@@ -404,7 +407,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
-    def find_first_mismatch(ids1: list[int], ids2: list[int]):
+    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
             if a != b:
                 return i
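Widening `find_first_mismatch` to `list[int] | str` matches call sites that evidently pass decoded text as well as token-id lists; `zip()` and `!=` only need the arguments to be iterable, so nothing changes at runtime. A self-contained variant (the handling of unequal lengths after the loop is this sketch's own addition):

def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str) -> int:
    # Works for both union members: ints and one-character strings compare
    # with != just the same.
    for i, (a, b) in enumerate(zip(ids1, ids2)):
        if a != b:
            return i
    # If one sequence is a prefix of the other, the mismatch is at its end.
    return -1 if len(ids1) == len(ids2) else min(len(ids1), len(ids2))

print(find_first_mismatch([1, 2, 3], [1, 2, 4]))  # 2
print(find_first_mismatch("hello", "help"))       # 3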
@@ -433,7 +436,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     decode_errors = 0
     MAX_ERRORS = 10
-    logger.info("%s: %s" % (generator.__name__, "ini"))
+    logger.info("%s: %s" % (generator.__qualname__, "ini"))
     for text in generator:
         # print(repr(text), text.encode())
         # print(repr(text), hex(ord(text[0])), text.encode())
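`generator` is annotated as `Iterator[str]`, a protocol that declares no `__name__` attribute, which is presumably why Pyright flagged the old access; generator objects expose both `__name__` and `__qualname__` at runtime, and `__qualname__` has the added benefit of recording the enclosing scope of nested generator functions:

def outer():
    def gen():
        yield "x"
    return gen()

g = outer()
print(g.__name__)      # gen
print(g.__qualname__)  # outer.<locals>.gen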
@@ -472,10 +475,10 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
             break
     t_total = time.perf_counter() - t_start
-    logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+    logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
-def main(argv: list[str] = None):
+def main(argv: list[str] | None = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
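Typing `argv` as `list[str] | None` keeps the entry point testable: `None` means "use the real command line", which is exactly how `argparse` treats it. The pattern in miniature:

import argparse
import sys

def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
    # parse_args(None) falls back to sys.argv[1:], so main() works from the
    # command line and main(["./models/vocab.gguf"]) works from a test.
    args = parser.parse_args(argv)
    print(args.vocab_file)
    return 0

if __name__ == "__main__":
    sys.exit(main())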