bugfix for non llava-1.6
It should now work with llava-1.5 as well
This commit is contained in:
parent
c92431a0a4
commit
c9874dd0d6
1 changed files with 19 additions and 11 deletions
|
@ -38,7 +38,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
|||
# file_type = 'pytorch'
|
||||
model_path = os.path.dirname(checkpoint_path)
|
||||
print(f"Searching for vision tower tensors in {checkpoint_path}")
|
||||
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") ) ]
|
||||
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
|
||||
|
||||
if len(clip_tensors) > 0:
|
||||
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
|
||||
|
@ -46,8 +46,10 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
|||
clip_path = os.path.join(model_path, "llava.clip")
|
||||
|
||||
if os.path.exists(clip_path):
|
||||
print(f"Loading existing llava.clip from {clip_path}")
|
||||
existing_clip, _ = load_model(clip_path)
|
||||
else:
|
||||
print(f"Creating new llava.clip at {clip_path}")
|
||||
existing_clip = {}
|
||||
# Update existing_clip with new tensors, avoid duplicates
|
||||
for name in clip_tensors:
|
||||
|
@ -116,19 +118,24 @@ checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and '
|
|||
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
|
||||
|
||||
print(f"Taking projector from {projector_checkpoint_path}")
|
||||
print(f"Taking newline from {newline_checkpoint_path}")
|
||||
first_mm_tensors = []
|
||||
first_checkpoint = None
|
||||
if newline_checkpoint_path is not None:
|
||||
print(f"Taking newline from {newline_checkpoint_path}")
|
||||
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
||||
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
||||
|
||||
# Load the checkpoint
|
||||
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
||||
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
||||
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
||||
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
||||
|
||||
|
||||
mm_tensors = []
|
||||
last_checkpoint = None
|
||||
if projector_checkpoint_path is not None:
|
||||
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
||||
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
||||
|
||||
if len(mm_tensors) == 0:
|
||||
for k, v in last_checkpoint.items():
|
||||
print(k)
|
||||
if last_checkpoint is not None:
|
||||
for k, v in last_checkpoint.items():
|
||||
print(k)
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
||||
print("No tensors found. Is this a LLaVA model?")
|
||||
exit()
|
||||
|
@ -142,7 +149,8 @@ for name in mm_tensors:
|
|||
for name in first_mm_tensors:
|
||||
projector[name] = first_checkpoint[name].float()
|
||||
|
||||
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
|
||||
if len(projector) > 0:
|
||||
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
|
||||
|
||||
for name in mm_tensors:
|
||||
del last_checkpoint[name]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue