add README for llava.py

ningshanwutuobang 2023-06-17 16:41:37 +08:00
parent 8cea3ab9e5
commit 4f1aa3cc76
4 changed files with 70 additions and 15 deletions

@@ -0,0 +1,20 @@
## Examples for passing input embeddings directly

### LLaVA example (llava.py)

1. Obtain the LLaVA model (follow https://github.com/haotian-liu/LLaVA/ and use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
2. Convert it to ggml format (using llama.cpp's conversion script).
3. Extract `llava_projection.pth` from [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin):
```
import torch
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
pth_path = "./examples/embd_input/llava_projection.pth"
dic = torch.load(bin_path)
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
torch.save({k: dic[k] for k in used_key}, pth_path)
```
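
With the three steps above done, the example can be driven from Python. Below is a minimal usage sketch; the `Llava` API and the paths are taken from the `__main__` block of `llava.py` (run from the llama.cpp root with `examples/embd_input` on `PYTHONPATH`, and adjust the model and projection paths to your setup):
```
from PIL import Image
from llava import Llava  # llava.py in this example directory

a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
a.load_projection("./examples/embd_input/llava_projection.pth")
print(a.chat_with_image(
    Image.open("./media/llama1-logo.png").convert('RGB'),
    "what is the text in the picture?"))
print(a.chat("what is the color of it?"))
```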

@@ -272,7 +272,10 @@ llama_token sampling_id(struct MyModel* mymodel) {
 const char* sampling(struct MyModel* mymodel) {
     llama_context* ctx = mymodel->ctx;
     int id = sampling_id(mymodel);
-    std::string ret = llama_token_to_str(ctx, id);
+    std::string ret;
+    if (id == llama_token_eos()) ret = "</s>";
+    else ret = llama_token_to_str(ctx, id);
     eval_id(mymodel, id);
     return ret.c_str();
 }

@@ -25,7 +25,7 @@ int main(int argc, char** argv) {
     for (int i=0;i < 500; i++) {
         // int id = sampling_id(mymodel);
         tmp = sampling(mymodel);
-        if (strlen(tmp) == 0) break;
+        if (strcmp(tmp, "</s>")==0) break;
         printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
         fflush(stdout);
         // eval_id(mymodel, id);
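
Together with the library change above, generation now stops cleanly at end-of-sequence: `sampling()` returns the literal `"</s>"` for the EOS token, and callers break on that sentinel rather than on an empty string. The Python wrapper follows the same convention. A minimal sketch, assuming `MyModel` is the wrapper from `embd_input.py` in this example directory (which `llava.py` below builds on) and that its `sampling()` returns the token text as raw bytes:
```
from embd_input import MyModel  # assumed: the embd_input example's Python wrapper

model = MyModel(["main", "--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
model.eval_string("user: hello\nassistant: ")
out = b""
for _ in range(500):
    tok = model.sampling()  # bytes of one sampled token
    if tok == b"</s>":      # EOS now surfaces as the literal "</s>"
        break
    out += tok
print(out.decode())
```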

@@ -7,40 +7,72 @@ from torch import nn
 import torch
 from transformers import CLIPVisionModel, CLIPImageProcessor
 from PIL import Image
+# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
 vision_tower = "openai/clip-vit-large-patch14"
+select_hidden_state_layer = -2
+# (vision_config.image_size // vision_config.patch_size) ** 2
+image_token_len = (224//14)**2
 
 class Llava:
-    def __init__(self):
+    def __init__(self, args):
         self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
         self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
         self.mm_projector = nn.Linear(1024, 5120)
-        self.model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+        self.model = MyModel(["main", *args])
+
+    def load_projection(self, path):
+        state = torch.load(path)
+        self.mm_projector.load_state_dict({
+            "weight": state["model.mm_projector.weight"],
+            "bias": state["model.mm_projector.bias"]})
+
+    def chat(self, question):
+        self.model.eval_string("user: ")
+        self.model.eval_string(question)
+        self.model.eval_string("\nassistant: ")
+        return self.sampling()
 
     def chat_with_image(self, image, question):
         with torch.no_grad():
             embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
-            select_hidden_state_layer = -2
             select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
             image_feature = select_hidden_state[:, 1:]
             embd_image = self.mm_projector(image_feature)
-            embd_image = embd_image.cpu().numpy()
+            embd_image = embd_image.cpu().numpy()[0]
         self.model.eval_string("user: ")
-        # print(embd_image.shape)
+        self.model.eval_token(32003-2) # im_start
         self.model.eval_float(embd_image.T)
+        for i in range(image_token_len-embd_image.shape[0]):
+            self.model.eval_token(32003-3) # im_patch
+        self.model.eval_token(32003-1) # im_end
         self.model.eval_string(question)
         self.model.eval_string("\nassistant: ")
-        ret = ""
+        return self.sampling()
+
+    def sampling(self):
+        ret = b""
         for _ in range(500):
-            tmp = self.model.sampling().decode()
-            if tmp == "":
+            tmp = self.model.sampling() # .decode()
+            if tmp == b"</s>":
                 break
             ret += tmp
-        return ret
+        return ret.decode()
+
 if __name__=="__main__":
-    a = Llava()
-    state = torch.load(os.path.dirname(__file__) + "/a.pth")
-    a.mm_projector.load_state_dict({"weight": state["model.mm_projector.weight"], "bias": state["model.mm_projector.bias"]})
-    print(a.chat_with_image(Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?"))
+    # model from 'liuhaotian/LLaVA-13b-delta-v1-1'
+    a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
+    # llava_projection.pth is extracted from
+    # https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin
+    # (that full shard can also be loaded here directly).
+    a.load_projection(os.path.join(
+        os.path.dirname(__file__),
+        "llava_projection.pth"))
+    response = a.chat_with_image(
+        Image.open("./media/llama1-logo.png").convert('RGB'),
+        "what is the text in the picture?")
+    print(response)
+    print(a.chat("what is the color of it?"))
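
As a quick sanity check on the checkpoint extracted in step 3 of the README, the projection weights should match the `nn.Linear(1024, 5120)` layer used by `Llava` above — a small sketch, assuming `llava_projection.pth` was written as shown there:
```
import torch

state = torch.load("./examples/embd_input/llava_projection.pth")
w = state["model.mm_projector.weight"]
b = state["model.mm_projector.bias"]
# nn.Linear(1024, 5120) stores its weight as (out_features, in_features)
assert tuple(w.shape) == (5120, 1024)
assert tuple(b.shape) == (5120,)
```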