py : type-check all Python scripts with Pyright
Parent: 87e25a1d1b
Commit: e29fd9634c
35 changed files with 264 additions and 136 deletions
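The hunks below show a representative slice of the changes. For a sense of the kind of issue this pass addresses, here is a small illustration (a sketch, not code from the commit): Pyright by default does not treat a default value of None as implicitly making a parameter Optional, so such defaults have to be spelled out in the annotation.

    from typing import Optional

    def describe(name: str = None) -> str:        # Pyright flags this: None is not assignable to str
        return name or "unknown"

    def describe_ok(name: Optional[str] = None) -> str:  # explicit Optional is accepted
        return name or "unknown"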
@@ -353,7 +353,7 @@ class Metadata:
     version: Optional[str] = None
     url: Optional[str] = None
     description: Optional[str] = None
-    licence: Optional[str] = None
+    license: Optional[str] = None
     source_url: Optional[str] = None
     source_hf_repo: Optional[str] = None
@@ -492,12 +492,13 @@ class LazyTensor:
 
 LazyModel: TypeAlias = 'dict[str, LazyTensor]'
 
+ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']
 
 @dataclass
 class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
-    format: Literal['ggml', 'torch', 'safetensors', 'none']
+    format: ModelFormat
     vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.
 
 
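A hypothetical illustration (not from the commit, assuming Python 3.10+ for typing.TypeAlias) of what the new ModelFormat alias buys: assignments and arguments are checked against the allowed literal values, so a typo in a format string is reported instead of silently accepted.

    from typing import Literal, TypeAlias

    ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']

    fmt: ModelFormat = 'torch'     # accepted: one of the allowed literals
    bad: ModelFormat = 'pytorch'   # flagged by Pyright: not a member of the Literal union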
@@ -536,7 +537,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
 
 
 def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
-    formats = set(mp.format for mp in models_plus)
+    formats: set[ModelFormat] = set(mp.format for mp in models_plus)
     assert len(formats) == 1, "different formats?"
     format = formats.pop()
     paths = [path for mp in models_plus for path in mp.paths]
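The explicit set[ModelFormat] annotation guarantees that formats.pop() is typed as ModelFormat, so the value can flow back into a ModelPlus (whose format field now uses the same alias) without any suppression comment; see the next hunk, where the old pytype directive is dropped. A rough, self-contained sketch of that flow (hypothetical values, simplified):

    from typing import Literal, TypeAlias

    ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']

    formats: set[ModelFormat] = {'torch', 'torch'}  # duplicates collapse, as in the real merge
    fmt = formats.pop()                             # seen by Pyright as ModelFormat, not plain str
    # fmt can now be stored in a field declared as ModelFormat with no checker complaints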
@@ -555,7 +556,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
     else:
         model = merge_sharded([mp.model for mp in models_plus])
 
-    return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types
+    return ModelPlus(model, paths, format, vocab)
 
 
 def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
@@ -805,7 +806,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
-    def add_meta_model(self, params: Params, metadata: Metadata) -> None:
+    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
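The Metadata | None parameter type matches how the method is actually called: callers may have no metadata at all, and the body already guards with "is not None", which Pyright uses to narrow the type. A minimal sketch of that narrowing (hypothetical class, not the real Metadata):

    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class Meta:                          # hypothetical stand-in for the real Metadata dataclass
        name: str | None = None

    def pick_name(metadata: Meta | None) -> str:
        # Pyright narrows metadata to Meta (and metadata.name to str) inside this branch.
        if metadata is not None and metadata.name is not None:
            return metadata.name
        return "LLaMA"                   # default, mirroring the hunk above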
@@ -827,8 +828,8 @@ class OutputFile:
             self.gguf.add_url(metadata.url)
         if metadata.description is not None:
             self.gguf.add_description(metadata.description)
-        if metadata.licence is not None:
-            self.gguf.add_licence(metadata.licence)
+        if metadata.license is not None:
+            self.gguf.add_licence(metadata.license)
         if metadata.source_url is not None:
             self.gguf.add_source_url(metadata.source_url)
         if metadata.source_hf_repo is not None:
@@ -943,7 +944,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
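One spelling note (an assumption about the surrounding file, since its import block is not part of this hunk): writing Metadata | None in a signature requires either Python 3.10+ or "from __future__ import annotations"; on older interpreters without the future import, Optional[Metadata] is the equivalent form. A minimal sketch with a hypothetical stub class:

    from __future__ import annotations    # makes "Metadata | None" legal in annotations before 3.10

    class Metadata:                        # hypothetical stub standing in for the real dataclass
        ...

    def write_vocab_only_sketch(pad_vocab: bool = False, metadata: Metadata | None = None) -> None:
        # The explicit "| None" is what makes the default of None acceptable to Pyright.
        ...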
@@ -977,7 +978,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata = None,
+        metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -1396,6 +1397,8 @@ def main(args_in: list[str] | None = None) -> None:
     if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
         vocab = model_plus.vocab
 
+    assert params is not None
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model = model_plus.model
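The added "assert params is not None" is a narrowing assertion: at that point params presumably still has an Optional type (it is built up conditionally earlier in main, which this hunk does not show), and the assert tells Pyright that the uses below see a real Params. A minimal sketch of the pattern (hypothetical types):

    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class ParamsSketch:                        # hypothetical stand-in for the real Params
        n_vocab: int = 0

    def load_params(have_model: bool) -> ParamsSketch | None:
        return ParamsSketch() if have_model else None

    def main_sketch() -> None:
        params = load_params(have_model=True)  # inferred as ParamsSketch | None
        assert params is not None              # narrows to ParamsSketch for the rest of the function
        print(params.n_vocab)                  # no "possibly None" diagnostic after the assert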