py : type-check all Python scripts with Pyright
This commit is contained in:
parent
87e25a1d1b
commit
e29fd9634c
35 changed files with 264 additions and 136 deletions
|
@ -185,6 +185,8 @@ else:
|
|||
fout.add_description("two-tower CLIP model")
|
||||
|
||||
if has_text_encoder:
|
||||
assert t_hparams is not None
|
||||
assert tokens is not None
|
||||
# text_model hparams
|
||||
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
|
||||
|
@ -259,8 +261,8 @@ if has_vision_encoder:
|
|||
|
||||
|
||||
if processor is not None:
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
|
||||
image_std = args.image_std if args.image_std is not None else default_image_std
|
||||
|
@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)
|
|||
|
||||
|
||||
if has_llava_projector:
|
||||
model.vision_model.encoder.layers.pop(-1)
|
||||
model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
|
||||
projector = torch.load(args.llava_projector)
|
||||
for name, data in projector.items():
|
||||
name = get_tensor_name(name)
|
||||
|
@ -286,7 +288,7 @@ if has_llava_projector:
|
|||
|
||||
print("Projector tensors added\n")
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
|
||||
for name, data in state_dict.items():
|
||||
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
|
||||
# we don't need this
|
||||
|
|
|
@ -2,7 +2,9 @@ import argparse
|
|||
import glob
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from typing import Any, ContextManager, cast
|
||||
|
||||
# Function to determine if file is a SafeTensor file
|
||||
def is_safetensor_file(file_path):
|
||||
|
@ -13,7 +15,7 @@ def is_safetensor_file(file_path):
|
|||
def load_model(file_path):
|
||||
if is_safetensor_file(file_path):
|
||||
tensors = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key).clone()
|
||||
# output shape
|
||||
|
@ -134,7 +136,7 @@ if len(mm_tensors) == 0:
|
|||
if last_checkpoint is not None:
|
||||
for k, v in last_checkpoint.items():
|
||||
print(k)
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
|
||||
print("No tensors found. Is this a LLaVA model?")
|
||||
exit()
|
||||
|
||||
|
@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
|
|||
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
|
||||
projector = {}
|
||||
for name in mm_tensors:
|
||||
assert last_checkpoint is not None
|
||||
projector[name] = last_checkpoint[name].float()
|
||||
for name in first_mm_tensors:
|
||||
assert first_checkpoint is not None
|
||||
projector[name] = first_checkpoint[name].float()
|
||||
|
||||
if len(projector) > 0:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
-r ../../requirements/requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pillow~=10.2.0
|
||||
torch~=2.2.1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue