bugfix for non llava-1.6

It should now work with llava-1.5 as well
This commit is contained in:
John 2024-02-14 05:05:57 +01:00 committed by GitHub
parent c92431a0a4
commit c9874dd0d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -38,7 +38,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") ) ]
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@ -46,8 +46,10 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
clip_path = os.path.join(model_path, "llava.clip")
if os.path.exists(clip_path):
print(f"Loading existing llava.clip from {clip_path}")
existing_clip, _ = load_model(clip_path)
else:
print(f"Creating new llava.clip at {clip_path}")
existing_clip = {}
# Update existing_clip with new tensors, avoid duplicates
for name in clip_tensors:
@ -116,19 +118,24 @@ checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and '
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
print(f"Taking projector from {projector_checkpoint_path}")
print(f"Taking newline from {newline_checkpoint_path}")
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
# Load the checkpoint
first_checkpoint, file_type = load_model(newline_checkpoint_path)
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
if len(mm_tensors) == 0:
for k, v in last_checkpoint.items():
print(k)
if last_checkpoint is not None:
for k, v in last_checkpoint.items():
print(k)
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
print("No tensors found. Is this a LLaVA model?")
exit()
@ -142,7 +149,8 @@ for name in mm_tensors:
for name in first_mm_tensors:
projector[name] = first_checkpoint[name].float()
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
for name in mm_tensors:
del last_checkpoint[name]