Update llava-surgery-v2.py

This commit is contained in:
John 2024-02-02 02:07:42 +01:00 committed by GitHub
parent 440b2ae2b1
commit 35b7a7a183
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -51,7 +51,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
existing_clip = {} existing_clip = {}
# Update existing_clip with new tensors, avoid duplicates # Update existing_clip with new tensors, avoid duplicates
for name in clip_tensors: for name in clip_tensors:
simple_name = name.replace("vision_tower.vision_tower.", "") simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
print(f"Adding {simple_name} to llava.clip") print(f"Adding {simple_name} to llava.clip")
if simple_name not in existing_clip: if simple_name not in existing_clip:
existing_clip[simple_name] = checkpoint[name] existing_clip[simple_name] = checkpoint[name]
@ -69,6 +69,25 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
return True return True
return False return False
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    """Pick the checkpoint files that hold the newline and projector tensors.

    Each path in ``checkpoint_paths`` is loaded with ``load_model`` and the
    loaded checkpoint is passed to the two predicates.

    Args:
        checkpoint_paths: Iterable of checkpoint file paths, scanned in the
            order given.
        newline_criteria: Callable that receives a loaded checkpoint and
            returns a truthy value when it should supply the newline tensor.
        projector: Callable that receives a loaded checkpoint and returns a
            truthy value when it should supply the projector tensors.

    Returns:
        A ``(newline_checkpoint_path, projector_checkpoint_path)`` tuple.
        The newline path is the FIRST path that matches; the projector path
        is the LAST path that matches. Either element is ``None`` when no
        checkpoint satisfies the corresponding predicate, so callers must
        check for ``None`` before loading.
    """
    newline_checkpoint_path = None
    projector_checkpoint_path = None
    for path in checkpoint_paths:
        checkpoint, _ = load_model(path)
        # Only the first newline match is kept, so skip the key scan
        # entirely once a match has been recorded.
        if newline_checkpoint_path is None and newline_criteria(checkpoint):
            newline_checkpoint_path = path
        # No "is None" guard here: a later match overrides an earlier one.
        if projector(checkpoint):
            projector_checkpoint_path = path
        # Drop the reference before the next load so that two checkpoints
        # are not kept alive at the same time.
        del checkpoint
    return newline_checkpoint_path, projector_checkpoint_path
def newline_criteria(checkpoint):
    """Return True if any key of *checkpoint* starts with "model.image_newline".

    Args:
        checkpoint: Mapping whose keys are tensor-name strings.

    Returns:
        ``True`` when at least one key has the ``model.image_newline``
        prefix, ``False`` otherwise (including for an empty mapping).
    """
    return any(key.startswith("model.image_newline") for key in checkpoint)
def proj_criteria(checkpoint):
    """Return True if *checkpoint* contains any projector tensor name.

    A key counts as a projector tensor when it starts with
    ``model.mm_projector`` or with ``vision_proj.``.

    Args:
        checkpoint: Mapping whose keys are tensor-name strings.

    Returns:
        ``True`` when at least one key has one of the two prefixes,
        ``False`` otherwise (including for an empty mapping).
    """
    # str.startswith accepts a tuple of prefixes, so one call covers both.
    projector_prefixes = ("model.mm_projector", "vision_proj.")
    return any(key.startswith(projector_prefixes) for key in checkpoint)
# Command-line interface setup # Command-line interface setup
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
@ -81,25 +100,27 @@ if args.clean_vision_tower:
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
# checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))] # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
for last_checkpoint_path in checkpoint_paths: for projector_checkpoint_path in checkpoint_paths:
print(f"Cleaning {last_checkpoint_path}") print(f"Cleaning {projector_checkpoint_path}")
if not clean_vision_tower_from_checkpoint(last_checkpoint_path): if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
print(f"No vision tower found in {last_checkpoint_path}") print(f"No vision tower found in {projector_checkpoint_path}")
# we break once none is found, so far all models append them at the end # we break once none is found, so far all models append them at the end
break # break
print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.") print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
# Now we look for the projector in the last checkpoint # Now we look for the projector in the last checkpoint
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
last_checkpoint_path = checkpoint_paths[0] # last_checkpoint_path = checkpoint_paths[0]
first_checkpoint_path = checkpoint_paths[-1] # first_checkpoint_path = checkpoint_paths[-1]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
print(f"Taking projector from {last_checkpoint_path}") print(f"Taking projector from {projector_checkpoint_path}")
print(f"Taking newline from {newline_checkpoint_path}")
# Load the checkpoint # Load the checkpoint
first_checkpoint, file_type = load_model(first_checkpoint_path) first_checkpoint, file_type = load_model(newline_checkpoint_path)
last_checkpoint, file_type = load_model(last_checkpoint_path) last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
@ -129,9 +150,9 @@ for name in first_mm_tensors:
del first_checkpoint[name] del first_checkpoint[name]
if len(mm_tensors) > 0: if len(mm_tensors) > 0:
save_model(last_checkpoint, last_checkpoint_path, file_type) save_model(last_checkpoint, projector_checkpoint_path, file_type)
if len(first_mm_tensors) > 0: if len(first_mm_tensors) > 0:
save_model(first_checkpoint, first_checkpoint_path, file_type) save_model(first_checkpoint, newline_checkpoint_path, file_type)
print("Done!") print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")