add example of PandaGPT

ningshanwutuobang 2023-06-20 22:57:21 +08:00
parent 46490c7ad7
commit 53dfbbf553
5 changed files with 166 additions and 17 deletions


@@ -113,6 +113,10 @@ with open(output_path, "wb") as fout:
    write_file_header(fout, params)
    for k, v in model.items():
        if k.endswith(".default.weight"):
            # PEFT checkpoints such as PandaGPT's store lora tensors as "*.default.weight"; strip the suffix
            k = k.replace(".default.weight", ".weight")
        if k in ["llama_proj.weight", "llama_proj.bias"]:
            # the multimodal projection is not a lora tensor; panda_gpt.py loads it separately
            continue
        if k.endswith("lora_A.weight"):
            if v.dtype != torch.float16 and v.dtype != torch.float32:
                v = v.float()
@@ -120,7 +124,7 @@ with open(output_path, "wb") as fout:
        else:
            v = v.float()
        t = v.numpy()
        t = v.detach().numpy()
        tname = translate_tensor_name(k)
        print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
        write_tensor_header(fout, tname, t.shape, t.dtype)


@@ -1,14 +1,17 @@
### Examples for input embedding directly
## Requirement
build `libembd_input.so`
run the following command in the main dir (../../):
```
make
```
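
Once built, the library is driven through the `MyModel` wrapper in `embd_input.py`; a minimal sketch (the model path and prompt are only examples taken from the scripts below) looks like:
```
from embd_input import MyModel

# MyModel forwards main-style arguments to libembd_input.so
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
model.eval_string("user: hello\nassistant: ")
print(model.generate())
```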
## LLAVA example (llava.py)
1. Obtain the LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
2. build `libembd_input.so`
```
make
```
3. convert it to ggml format
4. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
2. Convert it to ggml format.
3. Extract `llava_projection.pth` from [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin):
```
import torch
@@ -21,3 +24,24 @@ used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
torch.save({k: dic[k] for k in used_key}, pth_path)
```
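
The hunk above shows only the first and last lines of that snippet; a complete sketch of the extraction (the concrete paths are assumptions, while `dic`, `pth_path` and `used_key` come from the visible lines) would be:
```
import torch

# load the checkpoint shard that contains the projection weights (assumed local copy)
dic = torch.load("pytorch_model-00003-of-00003.bin", map_location="cpu")
pth_path = "./models/llava/llava_projection.pth"  # assumed output path

used_key = ["model.mm_projector.weight", "model.mm_projector.bias"]
torch.save({k: dic[k] for k in used_key}, pth_path)
```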
## PandaGPT example (panda_gpt.py)
1. Obtain the PandaGPT lora model. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
The `adapter_config.json` is:
```
{
"peft_type": "LORA",
"fan_in_fan_out": false,
"bias": null,
"modules_to_save": null,
"r": 32,
"lora_alpha": 32,
"lora_dropout": 0.1,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
}
```
2. Prepare the `vicuna` v0 model.
3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
4. Clone the PandaGPT source.
5. Check the paths of the PandaGPT source, the ImageBind model, the lora model and the vicuna model in `panda_gpt.py`; a usage sketch follows these steps.
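
After the adapter has been converted with [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) (producing `ggml-adapter-model.bin`), the example can be driven the same way as the `__main__` block of `panda_gpt.py`; the sketch below simply mirrors it, assuming the default paths used by the script:
```
from panda_gpt import PandaGPT

a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048",
              "--lora", "./models/panda_gpt/ggml-adapter-model.bin", "--temp", "0"])
a.load_projection("./models/panda_gpt/adapter_model.bin")  # llama_proj weights from the PandaGPT checkpoint
a.chat_with_image({"image_paths": ["./media/llama1-logo.png"]},
                  "what is the text in the picture? 'llama' or 'lambda'?")
a.chat("what is the color of it?")
```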


@@ -33,6 +33,35 @@ class MyModel:
        s = libc.sampling(self.model)
        return s

    def generate(self, end="</s>"):
        ret = b""
        end = end.encode()
        for _ in range(500):
            tmp = self.sampling() # .decode()
            if (ret+tmp).endswith(end):
                break
            ret += tmp
        return ret.decode()

    def stream_generate(self, end="</s>"):
        ret = b""
        end = end.encode()
        head = b""
        for _ in range(500):
            tmp = self.sampling() # .decode()
            ret += tmp
            try:
                text = (head + tmp).decode()
                print(text, end="")
                head = b""
            except UnicodeDecodeError:
                # tmp may stop in the middle of a multi-byte character; buffer the raw bytes for the next token
                head += tmp
            if ret.endswith(end):
                break
        print("")
        return ret.decode()

if __name__ == "__main__":
    model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
    # print(model)


@@ -31,7 +31,7 @@ class Llava:
        self.model.eval_string("user: ")
        self.model.eval_string(question)
        self.model.eval_string("\nassistant: ")
        return self.sampling()
        return self.model.generate()

    def chat_with_image(self, image, question):
        with torch.no_grad():
@@ -49,16 +49,8 @@ class Llava:
        self.model.eval_token(32003-1) # im_end
        self.model.eval_string(question)
        self.model.eval_string("\nassistant: ")
        return self.sampling()
        return self.model.generate()

    def sampling(self):
        ret = b""
        for _ in range(500):
            tmp = self.model.sampling() # .decode()
            if tmp == b"</s>":
                break
            ret += tmp
        return ret.decode()
if __name__=="__main__":
    # model from liuhaotian/LLaVA-13b-delta-v1-1


@@ -0,0 +1,100 @@
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
# add the PandaGPT source tree to sys.path so its ImageBind modules can be imported
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "PandaGPT","code","model"))
from ImageBind.models import imagebind_model
from ImageBind import data
imagebind_ckpt_path = "./models/panda_gpt/"
ModalityType = imagebind_model.ModalityType
max_tgt_len = 400
class PandaGPT:
    def __init__(self, args):
        self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
        self.visual_encoder.eval()
        self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
        self.max_tgt_len = max_tgt_len
        self.model = MyModel(["main", *args])
        self.generated_text = ""
        self.device = "cpu"

    def load_projection(self, path):
        state = torch.load(path, map_location="cpu")
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})

    def chat(self, question):
        if self.generated_text == "":
            self.model.eval_string("###")
        self.model.eval_string(" Human: ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        ret = self.model.stream_generate(end="###")
        self.generated_text += ret
        return ret

    def chat_with_image(self, inputs, question):
        if self.generated_text == "":
            self.model.eval_string("###")
        self.model.eval_string(" Human: <Img>")
        embds = self.extract_multimoal_feature(inputs)
        for i in embds:
            self.model.eval_float(i.T)
        self.model.eval_string("</Img> " + question + "\n### Assistant:")
        ret = self.model.stream_generate(end="###")
        self.generated_text += ret
        return ret

    def extract_multimoal_feature(self, inputs):
        features = []
        for key in ["image", "audio", "video", "thermal"]:
            if key + "_paths" in inputs:
                embeds = self.encode_data(key, inputs[key+"_paths"])
                features.append(embeds)
        return features

    def encode_data(self, data_type, data_paths):
        type_map = {
            "image": ModalityType.VISION,
            "audio": ModalityType.AUDIO,
            "video": ModalityType.VISION,
            "thermal": ModalityType.THERMAL,
        }
        load_map = {
            "image": data.load_and_transform_vision_data,
            "audio": data.load_and_transform_audio_data,
            "video": data.load_and_transform_video_data,
            "thermal": data.load_and_transform_thermal_data
        }
        load_function = load_map[data_type]
        key = type_map[data_type]
        inputs = {key: load_function(data_paths, self.device)}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            embeds = embeddings[key]
        embeds = self.llama_proj(embeds).cpu().numpy()
        return embeds

if __name__=="__main__":
    # vicuna v0 base model plus the PandaGPT lora adapter converted with convert-lora-to-ggml.py
    a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
    # llama_proj.weight / llama_proj.bias come from the PandaGPT adapter checkpoint;
    # convert-lora-to-ggml.py skips them, so they are loaded into the projection layer here.
    a.load_projection("./models/panda_gpt/adapter_model.bin")
    a.chat_with_image(
        {"image_paths": ["./media/llama1-logo.png"]},
        "what is the text in the picture? 'llama' or 'lambda'?")
    a.chat("what is the color of it?")