Update llava-surgery-v2.py

This commit is contained in:
John 2024-02-02 02:07:42 +01:00 committed by GitHub
parent 440b2ae2b1
commit 35b7a7a183
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -51,7 +51,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
existing_clip = {}
# Update existing_clip with new tensors, avoid duplicates
for name in clip_tensors:
simple_name = name.replace("vision_tower.vision_tower.", "")
simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
print(f"Adding {simple_name} to llava.clip")
if simple_name not in existing_clip:
existing_clip[simple_name] = checkpoint[name]
@ -69,6 +69,25 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
return True
return False
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
newline_checkpoint_path = None
projector_checkpoint_path = None
for path in checkpoint_paths:
checkpoint, _ = load_model(path)
if newline_criteria(checkpoint) and newline_checkpoint_path is None:
newline_checkpoint_path = path
if projector(checkpoint):
projector_checkpoint_path = path
return newline_checkpoint_path, projector_checkpoint_path
def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
# Command-line interface setup
ap = argparse.ArgumentParser()
@ -81,25 +100,27 @@ if args.clean_vision_tower:
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
# checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
for last_checkpoint_path in checkpoint_paths:
print(f"Cleaning {last_checkpoint_path}")
if not clean_vision_tower_from_checkpoint(last_checkpoint_path):
print(f"No vision tower found in {last_checkpoint_path}")
for projector_checkpoint_path in checkpoint_paths:
print(f"Cleaning {projector_checkpoint_path}")
if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
print(f"No vision tower found in {projector_checkpoint_path}")
# we break once none is found, so far all models append them at the end
break
# break
print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
# Now we look for the projector in the last checkpoint
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
last_checkpoint_path = checkpoint_paths[0]
first_checkpoint_path = checkpoint_paths[-1]
# last_checkpoint_path = checkpoint_paths[0]
# first_checkpoint_path = checkpoint_paths[-1]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
print(f"Taking projector from {last_checkpoint_path}")
print(f"Taking projector from {projector_checkpoint_path}")
print(f"Taking newline from {newline_checkpoint_path}")
# Load the checkpoint
first_checkpoint, file_type = load_model(first_checkpoint_path)
last_checkpoint, file_type = load_model(last_checkpoint_path)
first_checkpoint, file_type = load_model(newline_checkpoint_path)
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
@ -129,9 +150,9 @@ for name in first_mm_tensors:
del first_checkpoint[name]
if len(mm_tensors) > 0:
save_model(last_checkpoint, last_checkpoint_path, file_type)
save_model(last_checkpoint, projector_checkpoint_path, file_type)
if len(first_mm_tensors) > 0:
save_model(first_checkpoint, first_checkpoint_path, file_type)
save_model(first_checkpoint, newline_checkpoint_path, file_type)
print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")