add xgen-mm surgery
This commit is contained in:
parent
b841d07408
commit
433f7aa287
1 changed files with 99 additions and 0 deletions
99
examples/xgenmm/xgenmm-surgery.py
Normal file
99
examples/xgenmm/xgenmm-surgery.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import torch
|
||||
import argparse
|
||||
from open_flamingo import create_model_and_transforms
|
||||
from omegaconf import OmegaConf
|
||||
import os
|
||||
import time
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_pth", type=str, default='/export/share/manli_shu/models/open-flamingo-dev/anyres_ablation_HFSiglip_patch128-kosmos_non_instruct-phi3_4k_instruct_nq128_pre_V3_5-llava_1p6_ocrmathmix_v4-8x8-ckpt2/checkpoint_0.pt')
|
||||
parser.add_argument('--save_pth', type=str, default='/export/share/yutong/xgenmm/llamacpp_wd')
|
||||
parser.add_argument('--version', type=str, default='siglip_kosmos_phi3_4k_instruct', help='help identify the version of the saved ckpt')
|
||||
return parser.parse_args()
|
||||
|
||||
VISION_ENCODER_KEY = 'vision_encoder'
|
||||
LLM_KEY = 'lang_model'
|
||||
PROJECTOR = 'vision_tokenizer'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load ckpt
|
||||
args = get_args()
|
||||
print("🟡 Loading ckpt...")
|
||||
start = time.time()
|
||||
ckpt = torch.load(args.ckpt_pth)["model_state_dict"]
|
||||
end = time.time()
|
||||
print(f"🟢 time used: [{end-start:.3f} s] | Done with loading ckpt")
|
||||
|
||||
# sanity check
|
||||
unexpected_component_keys = set()
|
||||
for k in list(ckpt.keys()):
|
||||
matched = False
|
||||
for c in ['vision_encoder', 'lang_model', 'vision_tokenizer']:
|
||||
if k.startswith(c):
|
||||
matched = True
|
||||
continue
|
||||
if not matched:
|
||||
unexpected_component_keys.add(k)
|
||||
|
||||
if len(unexpected_component_keys) > 0:
|
||||
print(f"❗❗❗ Unexpected component keys: {unexpected_component_keys}. Proceed with caution.")
|
||||
|
||||
save_dir = f"{args.save_pth}/{args.version}"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# get a list vl connector keys
|
||||
projector_tensors = {v.float(): v for k, v in ckpt.items() if k.startswith(PROJECTOR)}
|
||||
print("🟡 Saving project ckpt...")
|
||||
save_path = f"{save_dir}/xgenmm.projector"
|
||||
start = time.time()
|
||||
torch.save(projector_tensors, save_path)
|
||||
end = time.time()
|
||||
print(f"🟢 time used: [{end-start:.3f} s] | Save projector ckpt at: {save_path}")
|
||||
|
||||
# here we use the siglip
|
||||
vision_encoder_tensors = {v.float(): v for k, v in ckpt.items() if k.startswith(VISION_ENCODER_KEY)}
|
||||
print("🟡 Saving vision encoder ckpt...")
|
||||
save_path = f"{save_dir}/xgenmm.clip"
|
||||
start = time.time()
|
||||
torch.save(vision_encoder_tensors, save_path)
|
||||
end = time.time()
|
||||
print(f"🟢 time used: [{end-start:.3f} s] | Save projector ckpt at: {save_path}")
|
||||
|
||||
|
||||
# hard code to load the model using open-flamingo
|
||||
print("🟡 Saving llm ckpt...")
|
||||
cfg = dict(
|
||||
model_family = 'kosmos',
|
||||
lm_path = 'microsoft/Phi-3-mini-4k-instruct',
|
||||
vision_encoder_path = 'google/siglip-so400m-patch14-384',
|
||||
vision_encoder_pretrained = 'google',
|
||||
num_vision_tokens = 128,
|
||||
image_aspect_ratio = 'anyres',
|
||||
anyres_patch_sampling = True,
|
||||
anyres_grids=[[1,2],[2,1],[2,2],[3,1],[1,3]],
|
||||
ckpt_pth = args.ckpt_pth)
|
||||
cfg = OmegaConf.create(cfg)
|
||||
if cfg.model_family in ['kosmos-instruct', 'kosmos', 'llava']:
|
||||
additional_kwargs = {
|
||||
"image_aspect_ratio": cfg.image_aspect_ratio,
|
||||
}
|
||||
if cfg.model_family in ['kosmos-instruct', 'kosmos']:
|
||||
additional_kwargs.update({
|
||||
"num_vision_tokens": cfg.num_vision_tokens,
|
||||
"anyres_patch_sampling": cfg.anyres_patch_sampling,
|
||||
})
|
||||
model, image_processor, tokenizer = create_model_and_transforms(
|
||||
clip_vision_encoder_path=cfg.vision_encoder_path,
|
||||
clip_vision_encoder_pretrained=cfg.vision_encoder_pretrained,
|
||||
lang_model_path=cfg.lm_path,
|
||||
tokenizer_path=cfg.lm_path,
|
||||
model_family=cfg.model_family,
|
||||
**additional_kwargs)
|
||||
model.load_state_dict(ckpt, strict=True)
|
||||
start = time.time()
|
||||
llm = model.lang_model.save_pretrained(f"{save_dir}/model")
|
||||
tokenizer.save_pretrained(f"{save_dir}/model")
|
||||
end = time.time()
|
||||
print(f"🟢 time used: [{end-start:.3f} s] | Save projector ckpt at: {save_dir}/model")
|
Loading…
Add table
Add a link
Reference in a new issue