bugfix for non llava-1.6
It should now work with llava-1.5 as well
This commit is contained in:
parent
c92431a0a4
commit
c9874dd0d6
1 changed files with 19 additions and 11 deletions
|
@ -38,7 +38,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
||||||
# file_type = 'pytorch'
|
# file_type = 'pytorch'
|
||||||
model_path = os.path.dirname(checkpoint_path)
|
model_path = os.path.dirname(checkpoint_path)
|
||||||
print(f"Searching for vision tower tensors in {checkpoint_path}")
|
print(f"Searching for vision tower tensors in {checkpoint_path}")
|
||||||
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") ) ]
|
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
|
||||||
|
|
||||||
if len(clip_tensors) > 0:
|
if len(clip_tensors) > 0:
|
||||||
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
|
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
|
||||||
|
@ -46,8 +46,10 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
||||||
clip_path = os.path.join(model_path, "llava.clip")
|
clip_path = os.path.join(model_path, "llava.clip")
|
||||||
|
|
||||||
if os.path.exists(clip_path):
|
if os.path.exists(clip_path):
|
||||||
|
print(f"Loading existing llava.clip from {clip_path}")
|
||||||
existing_clip, _ = load_model(clip_path)
|
existing_clip, _ = load_model(clip_path)
|
||||||
else:
|
else:
|
||||||
|
print(f"Creating new llava.clip at {clip_path}")
|
||||||
existing_clip = {}
|
existing_clip = {}
|
||||||
# Update existing_clip with new tensors, avoid duplicates
|
# Update existing_clip with new tensors, avoid duplicates
|
||||||
for name in clip_tensors:
|
for name in clip_tensors:
|
||||||
|
@ -116,17 +118,22 @@ checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and '
|
||||||
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
|
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
|
||||||
|
|
||||||
print(f"Taking projector from {projector_checkpoint_path}")
|
print(f"Taking projector from {projector_checkpoint_path}")
|
||||||
|
first_mm_tensors = []
|
||||||
|
first_checkpoint = None
|
||||||
|
if newline_checkpoint_path is not None:
|
||||||
print(f"Taking newline from {newline_checkpoint_path}")
|
print(f"Taking newline from {newline_checkpoint_path}")
|
||||||
|
|
||||||
# Load the checkpoint
|
|
||||||
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
||||||
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
|
||||||
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
|
||||||
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
||||||
|
|
||||||
|
# Load the checkpoint
|
||||||
|
mm_tensors = []
|
||||||
|
last_checkpoint = None
|
||||||
|
if projector_checkpoint_path is not None:
|
||||||
|
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
||||||
|
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
||||||
|
|
||||||
if len(mm_tensors) == 0:
|
if len(mm_tensors) == 0:
|
||||||
|
if last_checkpoint is not None:
|
||||||
for k, v in last_checkpoint.items():
|
for k, v in last_checkpoint.items():
|
||||||
print(k)
|
print(k)
|
||||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
||||||
|
@ -142,6 +149,7 @@ for name in mm_tensors:
|
||||||
for name in first_mm_tensors:
|
for name in first_mm_tensors:
|
||||||
projector[name] = first_checkpoint[name].float()
|
projector[name] = first_checkpoint[name].float()
|
||||||
|
|
||||||
|
if len(projector) > 0:
|
||||||
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
|
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
|
||||||
|
|
||||||
for name in mm_tensors:
|
for name in mm_tensors:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue