py : type-check all Python scripts with Pyright
This commit is contained in:
parent
87e25a1d1b
commit
e29fd9634c
35 changed files with 264 additions and 136 deletions
|
@ -185,6 +185,8 @@ else:
|
|||
fout.add_description("two-tower CLIP model")
|
||||
|
||||
if has_text_encoder:
|
||||
assert t_hparams is not None
|
||||
assert tokens is not None
|
||||
# text_model hparams
|
||||
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
|
||||
|
@ -259,8 +261,8 @@ if has_vision_encoder:
|
|||
|
||||
|
||||
if processor is not None:
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
|
||||
image_std = args.image_std if args.image_std is not None else default_image_std
|
||||
|
@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)
|
|||
|
||||
|
||||
if has_llava_projector:
|
||||
model.vision_model.encoder.layers.pop(-1)
|
||||
model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
|
||||
projector = torch.load(args.llava_projector)
|
||||
for name, data in projector.items():
|
||||
name = get_tensor_name(name)
|
||||
|
@ -286,7 +288,7 @@ if has_llava_projector:
|
|||
|
||||
print("Projector tensors added\n")
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
|
||||
for name, data in state_dict.items():
|
||||
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
|
||||
# we don't need this
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue