From cc1c135367937261e3b4823476c596df37e25d40 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Wed, 22 Jan 2025 01:43:49 -0700
Subject: [PATCH] Clean up llava surgery and remove name substitution hacks

Signed-off-by: Alex-Brooks
---
 examples/llava/llava_surgery_v2.py | 53 ++++++++++++++++++------------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py
index 5119c9ccc..b07c3e323 100644
--- a/examples/llava/llava_surgery_v2.py
+++ b/examples/llava/llava_surgery_v2.py
@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
     else:
         torch.save(model, file_path)
 
+# Helpers to match weight names from specific components or
+# determine if a saved shard contains that component
+def is_vision_tower(weight_name):
+    return (
+        weight_name.startswith("model.vision_tower") or
+        weight_name.startswith("vit.") or
+        weight_name.startswith("vision_tower")
+    )
+
+def is_newline(weight_name):
+    return (
+        weight_name.startswith("model.image_newline") or
+        weight_name.startswith("image_newline")
+    )
+
+def is_mm_projector(weight_name):
+    return (
+        weight_name.startswith("model.mm_projector") or
+        weight_name.startswith("vision_proj.") or
+        weight_name.startswith("multi_modal_projector")
+    )
+
+def newline_criteria(checkpoint):
+    return any(is_newline(k) for k in checkpoint.keys())
+
+def proj_criteria(checkpoint):
+    return any(is_mm_projector(k) for k in checkpoint.keys())
 
 # Adapted function to clean vision tower from checkpoint
 def clean_vision_tower_from_checkpoint(checkpoint_path):
@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
     # file_type = 'pytorch'
     model_path = os.path.dirname(checkpoint_path)
     print(f"Searching for vision tower tensors in {checkpoint_path}")
-    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))]
+    clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
 
     if len(clip_tensors) > 0:
         print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
     return newline_checkpoint_path, projector_checkpoint_path
 
 
-def newline_criteria(checkpoint):
-    return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys())
-
-def proj_criteria(checkpoint):
-    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys())
-
 
 # Command-line interface setup
 ap = argparse.ArgumentParser()
@@ -123,14 +144,14 @@ first_checkpoint = None
 if newline_checkpoint_path is not None:
     print(f"Taking newline from {newline_checkpoint_path}")
     first_checkpoint, file_type = load_model(newline_checkpoint_path)
-    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")]
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
 
 # Load the checkpoint
 mm_tensors = []
 last_checkpoint = None
 if projector_checkpoint_path is not None:
     last_checkpoint, file_type = load_model(projector_checkpoint_path)
-    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")]
+    mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
 
 if len(mm_tensors) == 0:
     if last_checkpoint is not None:
@@ -146,20 +167,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 projector = {}
 for name in mm_tensors:
     assert last_checkpoint is not None
-    # HACK - this should probably be in the second script...
-    new_name = name
-    if new_name.startswith("multi_modal_projector.linear_1"):
-        new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0")
-    elif new_name.startswith("multi_modal_projector.linear_2"):
-        new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2")
-    projector[new_name] = last_checkpoint[name].float()
+    projector[name] = last_checkpoint[name].float()
 for name in first_mm_tensors:
     assert first_checkpoint is not None
-    # HACK - this should probably be in the second script too...
-    new_name = name
-    if new_name == "image_newline":
-        new_name = "model.image_newline"
-    projector[new_name] = first_checkpoint[name].float()
+    projector[name] = first_checkpoint[name].float()
 
 if len(projector) > 0:
     save_model(projector, f"{args.model}/llava.projector", 'pytorch')