# Commentary (not part of the patch; patch tools skip text that precedes the
# first "diff --git" header).
#
# Purpose: patch for examples/llava/llava-surgery-v2.py that changes how CLIP
# tensor names are simplified and how the checkpoints holding the projector
# and image-newline tensors are located.
#
# NOTE: the patch text below has been flattened -- the original line breaks
# and indentation are not preserved in this copy, so it will not apply as-is
# and the Python nesting described here is read from statement order only.
#
# Hunk 1 (@@ -51,7 +51,7 @@, clean_vision_tower_from_checkpoint):
#   simple_name was name.replace("vision_tower.vision_tower.", ""); it now
#   keeps the substring starting at the first 'vision_model.' when that
#   marker occurs in the name, and leaves the name unchanged otherwise.
#
# Hunk 2 (@@ -69,6 +69,25 @@): adds three module-level helpers.
#   find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector)
#     calls load_model() on every path and returns the tuple
#     (newline_checkpoint_path, projector_checkpoint_path).
#     The newline path is guarded by "is None", so the first match wins; the
#     projector path has no such guard, so (assuming both "if" statements sit
#     at loop level) the last match in iteration order wins.
#     NOTE(review): either element is None when nothing matches; the script
#     later passes both straight to load_model() -- confirm that is handled.
#     NOTE(review): every checkpoint is loaded here and the two selected ones
#     are loaded again afterwards.
#   newline_criteria(checkpoint): True if any key starts with
#     "model.image_newline".
#   proj_criteria(checkpoint): True if any key starts with
#     "model.mm_projector" or "vision_proj.".
#
# Hunk 3 (@@ -81,25 +100,27 @@, top-level script):
#   * The cleaning loop variable is renamed to projector_checkpoint_path and
#     the "break" is commented out, so the loop no longer stops at the first
#     checkpoint without a vision tower. The comment above it ("we break once
#     none is found...") is left unchanged and no longer matches the code.
#   * last_checkpoint_path / first_checkpoint_path (checkpoint_paths[0] and
#     checkpoint_paths[-1]) are commented out and replaced by the result of
#     find_relevant_checkpoints(checkpoint_paths, newline_criteria,
#     proj_criteria); both chosen paths are printed.
#   * first_checkpoint is now loaded from newline_checkpoint_path and
#     last_checkpoint from projector_checkpoint_path; the variable names
#     themselves are unchanged.
#   * file_type is assigned by both load_model() calls, so the value from the
#     projector checkpoint is the one used afterwards.
#
# Hunk 4 (@@ -129,9 +150,9 @@): the two save_model() calls write back to
#   projector_checkpoint_path and newline_checkpoint_path respectively, both
#   using the single file_type value noted above.
#
diff --git a/examples/llava/llava-surgery-v2.py b/examples/llava/llava-surgery-v2.py index 51f9cb638..a5850b96e 100644 --- a/examples/llava/llava-surgery-v2.py +++ b/examples/llava/llava-surgery-v2.py @@ -51,7 +51,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): existing_clip = {} # Update existing_clip with new tensors, avoid duplicates for name in clip_tensors: - simple_name = name.replace("vision_tower.vision_tower.", "") + simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name print(f"Adding {simple_name} to llava.clip") if simple_name not in existing_clip: existing_clip[simple_name] = checkpoint[name] @@ -69,6 +69,25 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): return True return False +def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): + newline_checkpoint_path = None + projector_checkpoint_path = None + + for path in checkpoint_paths: + checkpoint, _ = load_model(path) + if newline_criteria(checkpoint) and newline_checkpoint_path is None: + newline_checkpoint_path = path + if projector(checkpoint): + projector_checkpoint_path = path + + return newline_checkpoint_path, projector_checkpoint_path + +def newline_criteria(checkpoint): + return any(k.startswith("model.image_newline") for k in checkpoint.keys()) + +def proj_criteria(checkpoint): + return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) + # Command-line interface setup ap = argparse.ArgumentParser() @@ -81,25 +100,27 @@ if args.clean_vision_tower: model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))] checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') 
and 'model' in path.split('/')[-1].split('\\')[-1])] - for last_checkpoint_path in checkpoint_paths: - print(f"Cleaning {last_checkpoint_path}") - if not clean_vision_tower_from_checkpoint(last_checkpoint_path): - print(f"No vision tower found in {last_checkpoint_path}") + for projector_checkpoint_path in checkpoint_paths: + print(f"Cleaning {projector_checkpoint_path}") + if not clean_vision_tower_from_checkpoint(projector_checkpoint_path): + print(f"No vision tower found in {projector_checkpoint_path}") # we break once none is found, so far all models append them at the end - break + # break print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.") # Now we look for the projector in the last checkpoint model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] -last_checkpoint_path = checkpoint_paths[0] -first_checkpoint_path = checkpoint_paths[-1] +# last_checkpoint_path = checkpoint_paths[0] +# first_checkpoint_path = checkpoint_paths[-1] +newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria) -print(f"Taking projector from {last_checkpoint_path}") +print(f"Taking projector from {projector_checkpoint_path}") +print(f"Taking newline from {newline_checkpoint_path}") # Load the checkpoint -first_checkpoint, file_type = load_model(first_checkpoint_path) -last_checkpoint, file_type = load_model(last_checkpoint_path) +first_checkpoint, file_type = load_model(newline_checkpoint_path) +last_checkpoint, file_type = load_model(projector_checkpoint_path) mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] first_mm_tensors = [k for k, v in 
first_checkpoint.items() if k.startswith("model.image_newline")] @@ -129,9 +150,9 @@ for name in first_mm_tensors: del first_checkpoint[name] if len(mm_tensors) > 0: - save_model(last_checkpoint, last_checkpoint_path, file_type) + save_model(last_checkpoint, projector_checkpoint_path, file_type) if len(first_mm_tensors) > 0: - save_model(first_checkpoint, first_checkpoint_path, file_type) + save_model(first_checkpoint, newline_checkpoint_path, file_type) print("Done!") print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")