py : type-check all Python scripts with Pyright

This commit is contained in:
Francis Couture-Harpin 2024-07-06 11:36:28 -04:00
parent 87e25a1d1b
commit e29fd9634c
35 changed files with 264 additions and 136 deletions

View file

@ -67,7 +67,7 @@ class ReaderTensor(NamedTuple):
class GGUFReader:
# I - same as host, S - swapped
byte_order: Literal['I'] | Literal['S'] = 'I'
byte_order: Literal['I', 'S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
data_offset: int
@ -86,7 +86,7 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
@ -140,7 +140,7 @@ class GGUFReader:
return self.tensors[idx]
def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)

View file

@ -16,16 +16,16 @@ logger = logging.getLogger(__name__)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
def __getattr__(self, name: str) -> Any:
meta_attr = getattr(self._meta, name)
if callable(meta_attr):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
@ -141,19 +141,21 @@ class LazyBase(ABC, metaclass=LazyMeta):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
def collect_replace(t: LazyBase):
if collect_replace.shared_lazy is None:
collect_replace.shared_lazy = t._lazy
else:
collect_replace.shared_lazy.extend(t._lazy)
t._lazy = collect_replace.shared_lazy
class CollectSharedLazy:
# emulating a static variable
shared_lazy: None | deque[LazyBase] = None
# emulating a static variable
collect_replace.shared_lazy = None
@staticmethod
def collect_replace(t: LazyBase):
if CollectSharedLazy.shared_lazy is None:
CollectSharedLazy.shared_lazy = t._lazy
else:
CollectSharedLazy.shared_lazy.extend(t._lazy)
t._lazy = CollectSharedLazy.shared_lazy
LazyBase._recurse_apply(args, collect_replace)
LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
shared_lazy = collect_replace.shared_lazy
shared_lazy = CollectSharedLazy.shared_lazy
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
else:
@ -184,6 +186,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
assert lt._data is not None
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape

View file

@ -1,3 +1,5 @@
# pyright: reportUnusedImport=false
from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
from .gguf_dump import main as gguf_dump_entrypoint
from .gguf_set_metadata import main as gguf_set_metadata_entrypoint

View file

@ -1,4 +1,6 @@
#!/usr/bin/env python3
from __future__ import annotations
import logging
import argparse
import os

View file

@ -1,4 +1,4 @@
import gguf # noqa: F401
import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
# TODO: add tests