add READMD for llava.py
This commit is contained in:
parent
8cea3ab9e5
commit
4f1aa3cc76
4 changed files with 70 additions and 15 deletions
20
examples/embd_input/README.md
Normal file
20
examples/embd_input/README.md
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
### Examples for input embedding directly
|
||||||
|
|
||||||
|
## LLAVA example (llava.py)
|
||||||
|
|
||||||
|
1. obtian llava model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/)
|
||||||
|
2. convert it to ggml format
|
||||||
|
3. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
|
||||||
|
|
||||||
|
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
|
||||||
|
pth_path = "./examples/embd_input/llava_projection.pth"
|
||||||
|
|
||||||
|
dic = torch.load(bin_path)
|
||||||
|
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
|
||||||
|
torch.save({k: dic[k] for k in used_key}, pth_path)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -272,7 +272,10 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
||||||
const char* sampling(struct MyModel* mymodel) {
|
const char* sampling(struct MyModel* mymodel) {
|
||||||
llama_context* ctx = mymodel->ctx;
|
llama_context* ctx = mymodel->ctx;
|
||||||
int id = sampling_id(mymodel);
|
int id = sampling_id(mymodel);
|
||||||
std::string ret = llama_token_to_str(ctx, id);
|
|
||||||
|
std::string ret;
|
||||||
|
if (id == llama_token_eos()) ret = "</s>";
|
||||||
|
else ret = llama_token_to_str(ctx, id);
|
||||||
eval_id(mymodel, id);
|
eval_id(mymodel, id);
|
||||||
return ret.c_str();
|
return ret.c_str();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ int main(int argc, char** argv) {
|
||||||
for (int i=0;i < 500; i++) {
|
for (int i=0;i < 500; i++) {
|
||||||
// int id = sampling_id(mymodel);
|
// int id = sampling_id(mymodel);
|
||||||
tmp = sampling(mymodel);
|
tmp = sampling(mymodel);
|
||||||
if (strlen(tmp) == 0) break;
|
if (strcmp(tmp, "</s>")==0) break;
|
||||||
printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
|
printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
// eval_id(mymodel, id);
|
// eval_id(mymodel, id);
|
||||||
|
|
|
@ -7,40 +7,72 @@ from torch import nn
|
||||||
import torch
|
import torch
|
||||||
from transformers import CLIPVisionModel, CLIPImageProcessor
|
from transformers import CLIPVisionModel, CLIPImageProcessor
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
|
||||||
vision_tower = "openai/clip-vit-large-patch14"
|
vision_tower = "openai/clip-vit-large-patch14"
|
||||||
|
select_hidden_state_layer = -2
|
||||||
|
# (vision_config.image_size // vision_config.patch_size) ** 2
|
||||||
|
image_token_len = (224//14)**2
|
||||||
|
|
||||||
class Llava:
|
class Llava:
|
||||||
def __init__(self):
|
def __init__(self, args):
|
||||||
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
||||||
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
||||||
self.mm_projector = nn.Linear(1024, 5120)
|
self.mm_projector = nn.Linear(1024, 5120)
|
||||||
self.model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
|
self.model = MyModel(["main", *args])
|
||||||
|
|
||||||
|
def load_projection(self, path):
|
||||||
|
state = torch.load(path)
|
||||||
|
self.mm_projector.load_state_dict({
|
||||||
|
"weight": state["model.mm_projector.weight"],
|
||||||
|
"bias": state["model.mm_projector.bias"]})
|
||||||
|
|
||||||
|
def chat(self, question):
|
||||||
|
self.model.eval_string("user: ")
|
||||||
|
self.model.eval_string(question)
|
||||||
|
self.model.eval_string("\nassistant: ")
|
||||||
|
return self.sampling()
|
||||||
|
|
||||||
def chat_with_image(self, image, question):
|
def chat_with_image(self, image, question):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||||
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
|
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
|
||||||
select_hidden_state_layer = -2
|
|
||||||
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
||||||
image_feature = select_hidden_state[:, 1:]
|
image_feature = select_hidden_state[:, 1:]
|
||||||
embd_image = self.mm_projector(image_feature)
|
embd_image = self.mm_projector(image_feature)
|
||||||
embd_image = embd_image.cpu().numpy()
|
embd_image = embd_image.cpu().numpy()[0]
|
||||||
self.model.eval_string("user: ")
|
self.model.eval_string("user: ")
|
||||||
# print(embd_image.shape)
|
self.model.eval_token(32003-2) # im_start
|
||||||
self.model.eval_float(embd_image.T)
|
self.model.eval_float(embd_image.T)
|
||||||
|
for i in range(image_token_len-embd_image.shape[0]):
|
||||||
|
self.model.eval_token(32003-3) # im_patch
|
||||||
|
self.model.eval_token(32003-1) # im_end
|
||||||
self.model.eval_string(question)
|
self.model.eval_string(question)
|
||||||
self.model.eval_string("\nassistant: ")
|
self.model.eval_string("\nassistant: ")
|
||||||
ret = ""
|
return self.sampling()
|
||||||
|
|
||||||
|
def sampling(self):
|
||||||
|
ret = b""
|
||||||
for _ in range(500):
|
for _ in range(500):
|
||||||
tmp = self.model.sampling().decode()
|
tmp = self.model.sampling() # .decode()
|
||||||
if tmp == "":
|
if tmp == b"</s>":
|
||||||
break
|
break
|
||||||
ret += tmp
|
ret += tmp
|
||||||
return ret
|
return ret.decode()
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
||||||
|
a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
|
||||||
|
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
|
||||||
|
# Also here can use pytorch_model-00003-of-00003.bin directly.
|
||||||
|
a.load_projection(os.path.join(
|
||||||
|
os.path.dirname(__file__) ,
|
||||||
|
"llava_projetion.pth"))
|
||||||
|
respose = a.chat_with_image(
|
||||||
|
Image.open("./media/llama1-logo.png").convert('RGB'),
|
||||||
|
"what is the text in the picture?")
|
||||||
|
print(respose)
|
||||||
|
print(a.chat("what is the color of it?"))
|
||||||
|
|
||||||
a = Llava()
|
|
||||||
state = torch.load(os.path.dirname(__file__) + "/a.pth")
|
|
||||||
a.mm_projector.load_state_dict({"weight": state["model.mm_projector.weight"], "bias": state["model.mm_projector.bias"]})
|
|
||||||
print(a.chat_with_image(Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?"))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue