From c9874dd0d65c8e0d42588f287af02d8905999e21 Mon Sep 17 00:00:00 2001
From: John <78893154+cmp-nct@users.noreply.github.com>
Date: Wed, 14 Feb 2024 05:05:57 +0100
Subject: [PATCH] bugfix for non-llava-1.6 models

It should now work with llava-1.5 as well.
---
 examples/llava/llava-surgery-v2.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/examples/llava/llava-surgery-v2.py b/examples/llava/llava-surgery-v2.py
index f0ade4ceb..5bc5bc513 100644
--- a/examples/llava/llava-surgery-v2.py
+++ b/examples/llava/llava-surgery-v2.py
@@ -38,7 +38,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
     # file_type = 'pytorch'
     model_path = os.path.dirname(checkpoint_path)
     print(f"Searching for vision tower tensors in {checkpoint_path}")
-    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") ) ]
+    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
 
     if len(clip_tensors) > 0:
         print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -46,8 +46,10 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
         clip_path = os.path.join(model_path, "llava.clip")
 
         if os.path.exists(clip_path):
+            print(f"Loading existing llava.clip from {clip_path}")
             existing_clip, _ = load_model(clip_path)
         else:
+            print(f"Creating new llava.clip at {clip_path}")
             existing_clip = {}
         # Update existing_clip with new tensors, avoid duplicates
         for name in clip_tensors:
@@ -116,19 +118,24 @@ checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and '
 newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
 
 print(f"Taking projector from {projector_checkpoint_path}")
-print(f"Taking newline from {newline_checkpoint_path}")
+first_mm_tensors = []
+first_checkpoint = None
+if newline_checkpoint_path is not None:
+    print(f"Taking newline from {newline_checkpoint_path}")
+    first_checkpoint, file_type = load_model(newline_checkpoint_path)
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
 
 # Load the checkpoint
-first_checkpoint, file_type = load_model(newline_checkpoint_path)
-last_checkpoint, file_type = load_model(projector_checkpoint_path)
-mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
-first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
-
-
+mm_tensors = []
+last_checkpoint = None
+if projector_checkpoint_path is not None:
+    last_checkpoint, file_type = load_model(projector_checkpoint_path)
+    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
 
 if len(mm_tensors) == 0:
-    for k, v in last_checkpoint.items():
-        print(k)
+    if last_checkpoint is not None:
+        for k, v in last_checkpoint.items():
+            print(k)
     print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
     print("No tensors found. Is this a LLaVA model?")
     exit()
@@ -142,7 +149,8 @@ for name in mm_tensors:
 for name in first_mm_tensors:
     projector[name] = first_checkpoint[name].float()
 
-save_model(projector, f"{args.model}/llava.projector", 'pytorch')
+if len(projector) > 0:
+    save_model(projector, f"{args.model}/llava.projector", 'pytorch')
 
 for name in mm_tensors:
     del last_checkpoint[name]
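
Note for reviewers: llava-1.5 checkpoints carry no model.image_newline tensor (it arrived with llava-1.6's split-image handling), so find_relevant_checkpoints can return None for the newline checkpoint, and the previous unconditional load_model(newline_checkpoint_path) crashed there. Below is a minimal, self-contained sketch of the guard pattern this patch applies on both loading paths; load_model here is a stand-in for the script's real loader, and the checkpoint name and contents are made up for illustration.

def load_model(path):
    # Stand-in loader: the real script reads .bin / .safetensors checkpoints.
    return {"model.mm_projector.0.weight": 0.0}, "pytorch"

def collect(path, prefixes):
    # A missing checkpoint path (None) now yields an empty tensor list
    # instead of crashing inside load_model(None).
    if path is None:
        return [], None
    checkpoint, _ = load_model(path)
    return [k for k in checkpoint if k.startswith(prefixes)], checkpoint

# llava-1.5: no newline checkpoint was found
first_mm_tensors, first_checkpoint = collect(None, "model.image_newline")
# projector tensors are still extracted as before
mm_tensors, last_checkpoint = collect("pytorch_model-00003.bin",
                                      ("model.mm_projector", "vision_proj."))

assert first_mm_tensors == [] and first_checkpoint is None
assert mm_tensors == ["model.mm_projector.0.weight"]

One residual edge case worth flagging: when no projector tensors are found, the error path still evaluates len(last_checkpoint), which raises TypeError if last_checkpoint is None; guarding that print the same way would let the "Is this a LLaVA model?" message actually be reached.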