Update llava-surgery-v2.py
This commit is contained in:
parent
440b2ae2b1
commit
35b7a7a183
1 changed files with 34 additions and 13 deletions
|
@ -51,7 +51,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
|||
existing_clip = {}
|
||||
# Update existing_clip with new tensors, avoid duplicates
|
||||
for name in clip_tensors:
|
||||
simple_name = name.replace("vision_tower.vision_tower.", "")
|
||||
simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
|
||||
print(f"Adding {simple_name} to llava.clip")
|
||||
if simple_name not in existing_clip:
|
||||
existing_clip[simple_name] = checkpoint[name]
|
||||
|
@ -69,6 +69,25 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
|
|||
return True
|
||||
return False
|
||||
|
||||
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
|
||||
newline_checkpoint_path = None
|
||||
projector_checkpoint_path = None
|
||||
|
||||
for path in checkpoint_paths:
|
||||
checkpoint, _ = load_model(path)
|
||||
if newline_criteria(checkpoint) and newline_checkpoint_path is None:
|
||||
newline_checkpoint_path = path
|
||||
if projector(checkpoint):
|
||||
projector_checkpoint_path = path
|
||||
|
||||
return newline_checkpoint_path, projector_checkpoint_path
|
||||
|
||||
def newline_criteria(checkpoint):
|
||||
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
|
||||
|
||||
def proj_criteria(checkpoint):
|
||||
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
|
||||
|
||||
|
||||
# Command-line interface setup
|
||||
ap = argparse.ArgumentParser()
|
||||
|
@ -81,25 +100,27 @@ if args.clean_vision_tower:
|
|||
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
|
||||
# checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
|
||||
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
|
||||
for last_checkpoint_path in checkpoint_paths:
|
||||
print(f"Cleaning {last_checkpoint_path}")
|
||||
if not clean_vision_tower_from_checkpoint(last_checkpoint_path):
|
||||
print(f"No vision tower found in {last_checkpoint_path}")
|
||||
for projector_checkpoint_path in checkpoint_paths:
|
||||
print(f"Cleaning {projector_checkpoint_path}")
|
||||
if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
|
||||
print(f"No vision tower found in {projector_checkpoint_path}")
|
||||
# we break once none is found, so far all models append them at the end
|
||||
break
|
||||
# break
|
||||
print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
|
||||
|
||||
# Now we look for the projector in the last checkpoint
|
||||
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
|
||||
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
|
||||
last_checkpoint_path = checkpoint_paths[0]
|
||||
first_checkpoint_path = checkpoint_paths[-1]
|
||||
# last_checkpoint_path = checkpoint_paths[0]
|
||||
# first_checkpoint_path = checkpoint_paths[-1]
|
||||
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
|
||||
|
||||
print(f"Taking projector from {last_checkpoint_path}")
|
||||
print(f"Taking projector from {projector_checkpoint_path}")
|
||||
print(f"Taking newline from {newline_checkpoint_path}")
|
||||
|
||||
# Load the checkpoint
|
||||
first_checkpoint, file_type = load_model(first_checkpoint_path)
|
||||
last_checkpoint, file_type = load_model(last_checkpoint_path)
|
||||
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
||||
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
||||
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
||||
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
||||
|
||||
|
@ -129,9 +150,9 @@ for name in first_mm_tensors:
|
|||
del first_checkpoint[name]
|
||||
|
||||
if len(mm_tensors) > 0:
|
||||
save_model(last_checkpoint, last_checkpoint_path, file_type)
|
||||
save_model(last_checkpoint, projector_checkpoint_path, file_type)
|
||||
if len(first_mm_tensors) > 0:
|
||||
save_model(first_checkpoint, first_checkpoint_path, file_type)
|
||||
save_model(first_checkpoint, newline_checkpoint_path, file_type)
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue