py : type-check all Python scripts with Pyright

This commit is contained in:
Francis Couture-Harpin 2024-07-06 11:36:28 -04:00
parent 87e25a1d1b
commit e29fd9634c
35 changed files with 264 additions and 136 deletions

View file

@ -2,7 +2,9 @@ import argparse
import glob
import os
import torch
from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
from safetensors import safe_open
from safetensors.torch import save_file
from typing import Any, ContextManager, cast
# Function to determine if file is a SafeTensor file
def is_safetensor_file(file_path):
@ -13,7 +15,7 @@ def is_safetensor_file(file_path):
def load_model(file_path):
if is_safetensor_file(file_path):
tensors = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key).clone()
# output shape
@ -134,7 +136,7 @@ if len(mm_tensors) == 0:
if last_checkpoint is not None:
for k, v in last_checkpoint.items():
print(k)
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
print("No tensors found. Is this a LLaVA model?")
exit()
@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
projector = {}
for name in mm_tensors:
assert last_checkpoint is not None
projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
assert first_checkpoint is not None
projector[name] = first_checkpoint[name].float()
if len(projector) > 0: