Clean up llava surgery and remove name substitution hacks
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent 92046a103d
commit cc1c135367
1 changed file with 32 additions and 21 deletions
@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
     else:
         torch.save(model, file_path)
 
+# Helpers to match weight names from specific components or
+# determine if a saved shard contains that component
+def is_vision_tower(weight_name):
+    return (
+        weight_name.startswith("model.vision_tower") or
+        weight_name.startswith("vit.") or
+        weight_name.startswith("vision_tower")
+    )
+
+def is_newline(weight_name):
+    return (
+        weight_name.startswith("model.image_newline") or
+        weight_name.startswith("image_newline")
+    )
+
+def is_mm_projector(weight_name):
+    return (
+        weight_name.startswith("model.mm_projector") or
+        weight_name.startswith("vision_proj.") or
+        weight_name.startswith("multi_modal_projector")
+    )
+
+def newline_criteria(checkpoint):
+    return any(is_newline(k) for k in checkpoint.keys())
+
+def proj_criteria(checkpoint):
+    return any(is_mm_projector(k) for k in checkpoint.keys())
 
 # Adapted function to clean vision tower from checkpoint
 def clean_vision_tower_from_checkpoint(checkpoint_path):
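The new helpers only prefix-match weight names. As a minimal illustration (not part of the change, and assuming the helper definitions from the hunk above are in scope), the keys below are hypothetical state-dict entries used purely to show how a shard would be classified:

# Illustrative only: hypothetical checkpoint keys, not taken from a real shard.
checkpoint = {
    "model.vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight": None,
    "model.mm_projector.0.weight": None,
    "model.image_newline": None,
    "model.layers.0.self_attn.q_proj.weight": None,
}

clip_keys = [k for k in checkpoint if is_vision_tower(k)]    # only the vision tower entry
proj_keys = [k for k in checkpoint if is_mm_projector(k)]    # only the mm_projector entry
print(newline_criteria(checkpoint), proj_criteria(checkpoint))  # True True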
@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
     # file_type = 'pytorch'
     model_path = os.path.dirname(checkpoint_path)
     print(f"Searching for vision tower tensors in {checkpoint_path}")
-    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))]
+    clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
 
     if len(clip_tensors) > 0:
         print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
 
     return newline_checkpoint_path, projector_checkpoint_path
 
-def newline_criteria(checkpoint):
-    return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys())
-
-def proj_criteria(checkpoint):
-    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys())
-
 
 # Command-line interface setup
 ap = argparse.ArgumentParser()
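The duplicated module-level definitions removed here are the same predicates now defined once next to the other helpers, which the script hands to find_relevant_checkpoints. A rough sketch of that call, under the assumption that checkpoint_paths is a list of shard paths gathered from the model directory (the file names here are illustrative):

# Hypothetical shard list; in the script these paths come from args.model.
checkpoint_paths = ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"]

# Locate the shard holding image_newline and the shard holding the projector weights.
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(
    checkpoint_paths, newline_criteria, proj_criteria
)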
@@ -123,14 +144,14 @@ first_checkpoint = None
 if newline_checkpoint_path is not None:
     print(f"Taking newline from {newline_checkpoint_path}")
     first_checkpoint, file_type = load_model(newline_checkpoint_path)
-    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")]
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
 
 # Load the checkpoint
 mm_tensors = []
 last_checkpoint = None
 if projector_checkpoint_path is not None:
     last_checkpoint, file_type = load_model(projector_checkpoint_path)
-    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")]
+    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
 
 if len(mm_tensors) == 0:
     if last_checkpoint is not None:
@@ -146,20 +167,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 projector = {}
 for name in mm_tensors:
     assert last_checkpoint is not None
-    # HACK - this should probably be in the second script...
-    new_name = name
-    if new_name.startswith("multi_modal_projector.linear_1"):
-        new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0")
-    elif new_name.startswith("multi_modal_projector.linear_2"):
-        new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2")
-    projector[new_name] = last_checkpoint[name].float()
+    projector[name] = last_checkpoint[name].float()
 for name in first_mm_tensors:
     assert first_checkpoint is not None
-    # HACK - this should probably be in the second script too...
-    new_name = name
-    if new_name == "image_newline":
-        new_name = "model.image_newline"
-    projector[new_name] = first_checkpoint[name].float()
+    projector[name] = first_checkpoint[name].float()
 
 if len(projector) > 0:
     save_model(projector, f"{args.model}/llava.projector", 'pytorch')
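With the in-place renaming removed (the deleted HACK comments themselves suggest it belongs in the later conversion script), llava.projector now stores tensors under their original checkpoint names. A quick way to confirm that after running the surgery, assuming PyTorch is installed; the path is a placeholder for the model directory passed via args.model:

import torch

# Expect original names such as "model.mm_projector.0.weight" or "model.image_newline",
# not the "mm.0" / "mm.2" forms that the removed substitution produced.
projector = torch.load("path/to/model/llava.projector", map_location="cpu")
print(sorted(projector.keys()))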