From 40340d82af4ca1fe9ca519b0e7cc42dff427406f Mon Sep 17 00:00:00 2001 From: ningshanwutuobang Date: Sun, 25 Jun 2023 23:42:26 +0800 Subject: [PATCH] Add MiniGPT-4 example --- .gitignore | 2 +- examples/embd-input/.gitignore | 4 + examples/embd-input/README.md | 36 ++++++--- examples/embd-input/minigpt4.py | 128 ++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 11 deletions(-) create mode 100644 examples/embd-input/.gitignore create mode 100644 examples/embd-input/minigpt4.py diff --git a/.gitignore b/.gitignore index 9116729fd..4fccec31b 100644 --- a/.gitignore +++ b/.gitignore @@ -40,7 +40,7 @@ models/* /vdot /server /Pipfile -/embd_input_test +/embd-input-test /libllama.so build-info.h arm_neon.h diff --git a/examples/embd-input/.gitignore b/examples/embd-input/.gitignore new file mode 100644 index 000000000..87ef68771 --- /dev/null +++ b/examples/embd-input/.gitignore @@ -0,0 +1,4 @@ +PandaGPT +MiniGPT-4 +*.pth + diff --git a/examples/embd-input/README.md b/examples/embd-input/README.md index eb1095f24..02d028f26 100644 --- a/examples/embd-input/README.md +++ b/examples/embd-input/README.md @@ -1,17 +1,17 @@ ### Examples for input embedding directly ## Requirement -build `libembd_input.so` +build `libembdinput.so` run the following comman in main dir (../../). ``` make ``` -## LLAVA example (llava.py) +## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py) -1. obtian llava model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/) -2. convert it to ggml format -3. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin) +1. Obtian LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/). +2. Convert it to ggml format. +3. `llava_projection.pth` is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin). ``` import torch @@ -23,10 +23,12 @@ dic = torch.load(bin_path) used_key = ["model.mm_projector.weight","model.mm_projector.bias"] torch.save({k: dic[k] for k in used_key}, pth_path) ``` +4. Check the path of LLaVA model and `llava_projection.pth` in `llava.py`. -## PandaGPT example (panda_gpt.py) -1. Obtian PandaGPT lora model. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format. +## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py) + +1. Obtian PandaGPT lora model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format. The `adapter_config.json` is ``` { @@ -40,8 +42,22 @@ The `adapter_config.json` is "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"] } ``` -2. papare the `vicuna` v0 model. -3. obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model. +2. Papare the `vicuna` v0 model. +3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model. 4. Clone the PandaGPT source. -5. check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py. +``` +git clone https://github.com/yxuansu/PandaGPT +``` +5. Install the requirement of PandaGPT. +6. Check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py. +## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py) + +1. Obtain MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`. +2. Clone the MiniGPT-4 source. +``` +git clone https://github.com/Vision-CAIR/MiniGPT-4/ +``` +3. Install the requirement of PandaGPT. +4. Papare the `vicuna` v0 model. +5. Check the path of MiniGPT-4 source, MiniGPT-4 model and vicuna model in `minigpt4.py`. diff --git a/examples/embd-input/minigpt4.py b/examples/embd-input/minigpt4.py new file mode 100644 index 000000000..8e98f8517 --- /dev/null +++ b/examples/embd-input/minigpt4.py @@ -0,0 +1,128 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch +from PIL import Image + +minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4") +sys.path.insert(0, minigpt4_path) +from minigpt4.models.blip2 import Blip2Base +from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor + + +class MiniGPT4(Blip2Base): + """ + MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4 + """ + def __init__(self, + args, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp32", + freeze_vit=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0 + ): + super().__init__() + self.img_size = img_size + self.low_resource = low_resource + self.preprocessor = Blip2ImageEvalProcessor(img_size) + + print('Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + print('Loading VIT Done') + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + print('Loading Q-Former Done') + self.llama_proj = nn.Linear( + self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size + ) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + self.model = MyModel(["main", *args]) + # system promt + self.model.eval_string("Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions." + "###") + + def encode_img(self, image): + image = self.preprocessor(image) + image = image.unsqueeze(0) + device = image.device + if self.low_resource: + self.vit_to_cpu() + image = image.to("cpu") + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama + + def load_projection(self, path): + state = torch.load(path)["model"] + self.llama_proj.load_state_dict({ + "weight": state["llama_proj.weight"], + "bias": state["llama_proj.bias"]}) + + def chat(self, question): + self.model.eval_string("Human: ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + def chat_with_image(self, image, question): + with torch.no_grad(): + embd_image = self.encode_img(image) + embd_image = embd_image.cpu().numpy()[0] + self.model.eval_string("Human: ") + self.model.eval_float(embd_image.T) + self.model.eval_string(" ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + +if __name__=="__main__": + a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"]) + a.load_projection(os.path.join( + os.path.dirname(__file__) , + "pretrained_minigpt4.pth")) + respose = a.chat_with_image( + Image.open("./media/llama1-logo.png").convert('RGB'), + "what is the text in the picture?") + a.chat("what is the color of it?")