add example of PandaGPT
This commit is contained in:
parent
46490c7ad7
commit
53dfbbf553
5 changed files with 166 additions and 17 deletions
|
@ -113,6 +113,10 @@ with open(output_path, "wb") as fout:
|
|||
|
||||
write_file_header(fout, params)
|
||||
for k, v in model.items():
|
||||
if k.endswith(".default.weight"):
|
||||
k = k.replace(".default.weight", ".weight")
|
||||
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
||||
continue
|
||||
if k.endswith("lora_A.weight"):
|
||||
if v.dtype != torch.float16 and v.dtype != torch.float32:
|
||||
v = v.float()
|
||||
|
@ -120,7 +124,7 @@ with open(output_path, "wb") as fout:
|
|||
else:
|
||||
v = v.float()
|
||||
|
||||
t = v.numpy()
|
||||
t = v.detach().numpy()
|
||||
tname = translate_tensor_name(k)
|
||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
### Examples for input embedding directly
|
||||
|
||||
## Requirement
|
||||
build `libembd_input.so`
|
||||
run the following comman in main dir (../../).
|
||||
```
|
||||
make
|
||||
```
|
||||
|
||||
## LLAVA example (llava.py)
|
||||
|
||||
1. obtian llava model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/)
|
||||
2. build `libembd_input.so`
|
||||
```
|
||||
make
|
||||
```
|
||||
3. convert it to ggml format
|
||||
4. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
|
||||
2. convert it to ggml format
|
||||
3. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
|
||||
|
||||
```
|
||||
import torch
|
||||
|
@ -21,3 +24,24 @@ used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
|
|||
torch.save({k: dic[k] for k in used_key}, pth_path)
|
||||
```
|
||||
|
||||
## PandaGPT example (panda_gpt.py)
|
||||
|
||||
1. Obtian PandaGPT lora model. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
|
||||
The `adapter_config.json` is
|
||||
```
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"fan_in_fan_out": false,
|
||||
"bias": null,
|
||||
"modules_to_save": null,
|
||||
"r": 32,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
}
|
||||
```
|
||||
2. papare the `vicuna` v0 model.
|
||||
3. obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
|
||||
4. Clone the PandaGPT source.
|
||||
5. check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py.
|
||||
|
||||
|
|
|
@ -33,6 +33,35 @@ class MyModel:
|
|||
s = libc.sampling(self.model)
|
||||
return s
|
||||
|
||||
def generate(self, end="</s>"):
|
||||
ret = b""
|
||||
end = end.encode()
|
||||
for _ in range(500):
|
||||
tmp = self.sampling() # .decode()
|
||||
if (ret+tmp).endswith(end):
|
||||
break
|
||||
ret += tmp
|
||||
return ret.decode()
|
||||
|
||||
def stream_generate(self, end="</s>"):
|
||||
ret = b""
|
||||
end = end.encode()
|
||||
head = b""
|
||||
for _ in range(500):
|
||||
tmp = self.sampling() # .decode()
|
||||
ret += tmp
|
||||
try:
|
||||
text = (head + tmp).decode()
|
||||
print(text, end="")
|
||||
head = b""
|
||||
except:
|
||||
head += text
|
||||
if ret.endswith(end):
|
||||
break
|
||||
print("")
|
||||
return ret.decode()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
|
||||
# print(model)
|
||||
|
|
|
@ -31,7 +31,7 @@ class Llava:
|
|||
self.model.eval_string("user: ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
return self.sampling()
|
||||
return self.model.generate()
|
||||
|
||||
def chat_with_image(self, image, question):
|
||||
with torch.no_grad():
|
||||
|
@ -49,16 +49,8 @@ class Llava:
|
|||
self.model.eval_token(32003-1) # im_end
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
return self.sampling()
|
||||
return self.model.generate()
|
||||
|
||||
def sampling(self):
|
||||
ret = b""
|
||||
for _ in range(500):
|
||||
tmp = self.model.sampling() # .decode()
|
||||
if tmp == b"</s>":
|
||||
break
|
||||
ret += tmp
|
||||
return ret.decode()
|
||||
|
||||
if __name__=="__main__":
|
||||
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
||||
|
|
100
examples/embd_input/panda_gpt.py
Normal file
100
examples/embd_input/panda_gpt.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from embd_input import MyModel
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
# use PandaGPT path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "PandaGPT","code","model"))
|
||||
from ImageBind.models import imagebind_model
|
||||
from ImageBind import data
|
||||
|
||||
imagebind_ckpt_path = "./models/panda_gpt/"
|
||||
ModalityType = imagebind_model.ModalityType
|
||||
max_tgt_len = 400
|
||||
|
||||
class PandaGPT:
|
||||
def __init__(self, args):
|
||||
self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
|
||||
self.visual_encoder.eval()
|
||||
self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
|
||||
self.max_tgt_len = max_tgt_len
|
||||
self.model = MyModel(["main", *args])
|
||||
self.generated_text = ""
|
||||
self.device = "cpu"
|
||||
|
||||
def load_projection(self, path):
|
||||
state = torch.load(path, map_location="cpu")
|
||||
self.llama_proj.load_state_dict({
|
||||
"weight": state["llama_proj.weight"],
|
||||
"bias": state["llama_proj.bias"]})
|
||||
|
||||
def chat(self, question):
|
||||
if self.generated_text == "":
|
||||
self.model.eval_string("###")
|
||||
self.model.eval_string(" Human: ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\n### Assistant:")
|
||||
ret = self.model.stream_generate(end="###")
|
||||
self.generated_text += ret
|
||||
return ret
|
||||
|
||||
def chat_with_image(self, inputs, question):
|
||||
if self.generated_text == "":
|
||||
self.model.eval_string("###")
|
||||
self.model.eval_string(" Human: <Img>")
|
||||
embds = self.extract_multimoal_feature(inputs)
|
||||
for i in embds:
|
||||
self.model.eval_float(i.T)
|
||||
self.model.eval_string("</Img> " + question + "\n### Assistant:")
|
||||
ret = self.model.stream_generate(end="###")
|
||||
self.generated_text += ret
|
||||
return ret
|
||||
|
||||
def extract_multimoal_feature(self, inputs):
|
||||
features = []
|
||||
for key in ["image", "audio", "video", "thermal"]:
|
||||
if key + "_paths" in inputs:
|
||||
embeds = self.encode_data(key, inputs[key+"_paths"])
|
||||
features.append(embeds)
|
||||
return features
|
||||
|
||||
def encode_data(self, data_type, data_paths):
|
||||
|
||||
type_map = {
|
||||
"image": ModalityType.VISION,
|
||||
"audio": ModalityType.AUDIO,
|
||||
"video": ModalityType.VISION,
|
||||
"thermal": ModalityType.THERMAL,
|
||||
}
|
||||
load_map = {
|
||||
"image": data.load_and_transform_vision_data,
|
||||
"audio": data.load_and_transform_audio_data,
|
||||
"video": data.load_and_transform_video_data,
|
||||
"thermal": data.load_and_transform_thermal_data
|
||||
}
|
||||
|
||||
load_function = load_map[data_type]
|
||||
key = type_map[data_type]
|
||||
|
||||
inputs = {key: load_function(data_paths, self.device)}
|
||||
with torch.no_grad():
|
||||
embeddings = self.visual_encoder(inputs)
|
||||
embeds = embeddings[key]
|
||||
embeds = self.llama_proj(embeds).cpu().numpy()
|
||||
return embeds
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
||||
a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
|
||||
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
|
||||
# Also here can use pytorch_model-00003-of-00003.bin directly.
|
||||
a.load_projection("./models/panda_gpt/adapter_model.bin")
|
||||
a.chat_with_image(
|
||||
{"image_paths": ["./media/llama1-logo.png"]},
|
||||
"what is the text in the picture? 'llama' or 'lambda'?")
|
||||
a.chat("what is the color of it?")
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue