add READMD for llava.py
This commit is contained in:
parent
8cea3ab9e5
commit
4f1aa3cc76
4 changed files with 70 additions and 15 deletions
20
examples/embd_input/README.md
Normal file
20
examples/embd_input/README.md
Normal file
|
@ -0,0 +1,20 @@
|
|||
### Examples for input embedding directly
|
||||
|
||||
## LLAVA example (llava.py)
|
||||
|
||||
1. obtian llava model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/)
|
||||
2. convert it to ggml format
|
||||
3. llava_projection.pth is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
|
||||
|
||||
```
|
||||
import torch
|
||||
|
||||
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
|
||||
pth_path = "./examples/embd_input/llava_projection.pth"
|
||||
|
||||
dic = torch.load(bin_path)
|
||||
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
|
||||
torch.save({k: dic[k] for k in used_key}, pth_path)
|
||||
```
|
||||
|
||||
|
|
@ -272,7 +272,10 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
|||
const char* sampling(struct MyModel* mymodel) {
|
||||
llama_context* ctx = mymodel->ctx;
|
||||
int id = sampling_id(mymodel);
|
||||
std::string ret = llama_token_to_str(ctx, id);
|
||||
|
||||
std::string ret;
|
||||
if (id == llama_token_eos()) ret = "</s>";
|
||||
else ret = llama_token_to_str(ctx, id);
|
||||
eval_id(mymodel, id);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ int main(int argc, char** argv) {
|
|||
for (int i=0;i < 500; i++) {
|
||||
// int id = sampling_id(mymodel);
|
||||
tmp = sampling(mymodel);
|
||||
if (strlen(tmp) == 0) break;
|
||||
if (strcmp(tmp, "</s>")==0) break;
|
||||
printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
|
||||
fflush(stdout);
|
||||
// eval_id(mymodel, id);
|
||||
|
|
|
@ -7,40 +7,72 @@ from torch import nn
|
|||
import torch
|
||||
from transformers import CLIPVisionModel, CLIPImageProcessor
|
||||
from PIL import Image
|
||||
|
||||
# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
|
||||
vision_tower = "openai/clip-vit-large-patch14"
|
||||
select_hidden_state_layer = -2
|
||||
# (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
image_token_len = (224//14)**2
|
||||
|
||||
class Llava:
|
||||
def __init__(self):
|
||||
def __init__(self, args):
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
||||
self.mm_projector = nn.Linear(1024, 5120)
|
||||
self.model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
|
||||
self.model = MyModel(["main", *args])
|
||||
|
||||
def load_projection(self, path):
|
||||
state = torch.load(path)
|
||||
self.mm_projector.load_state_dict({
|
||||
"weight": state["model.mm_projector.weight"],
|
||||
"bias": state["model.mm_projector.bias"]})
|
||||
|
||||
def chat(self, question):
|
||||
self.model.eval_string("user: ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
return self.sampling()
|
||||
|
||||
def chat_with_image(self, image, question):
|
||||
with torch.no_grad():
|
||||
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
||||
image_feature = select_hidden_state[:, 1:]
|
||||
embd_image = self.mm_projector(image_feature)
|
||||
embd_image = embd_image.cpu().numpy()
|
||||
embd_image = embd_image.cpu().numpy()[0]
|
||||
self.model.eval_string("user: ")
|
||||
# print(embd_image.shape)
|
||||
self.model.eval_token(32003-2) # im_start
|
||||
self.model.eval_float(embd_image.T)
|
||||
for i in range(image_token_len-embd_image.shape[0]):
|
||||
self.model.eval_token(32003-3) # im_patch
|
||||
self.model.eval_token(32003-1) # im_end
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
ret = ""
|
||||
return self.sampling()
|
||||
|
||||
def sampling(self):
|
||||
ret = b""
|
||||
for _ in range(500):
|
||||
tmp = self.model.sampling().decode()
|
||||
if tmp == "":
|
||||
tmp = self.model.sampling() # .decode()
|
||||
if tmp == b"</s>":
|
||||
break
|
||||
ret += tmp
|
||||
return ret
|
||||
return ret.decode()
|
||||
|
||||
if __name__=="__main__":
|
||||
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
||||
a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
|
||||
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
|
||||
# Also here can use pytorch_model-00003-of-00003.bin directly.
|
||||
a.load_projection(os.path.join(
|
||||
os.path.dirname(__file__) ,
|
||||
"llava_projetion.pth"))
|
||||
respose = a.chat_with_image(
|
||||
Image.open("./media/llama1-logo.png").convert('RGB'),
|
||||
"what is the text in the picture?")
|
||||
print(respose)
|
||||
print(a.chat("what is the color of it?"))
|
||||
|
||||
a = Llava()
|
||||
state = torch.load(os.path.dirname(__file__) + "/a.pth")
|
||||
a.mm_projector.load_state_dict({"weight": state["model.mm_projector.weight"], "bias": state["model.mm_projector.bias"]})
|
||||
print(a.chat_with_image(Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?"))
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue