add examples for llava.py

This commit is contained in:
ningshanwutuobang 2023-06-08 04:50:31 +08:00
parent 6ed4893391
commit 8cea3ab9e5
2 changed files with 48 additions and 2 deletions

View file

@ -21,7 +21,7 @@ class MyModel:
libc.free_mymodel(self.model) libc.free_mymodel(self.model)
def eval_float(self, x): def eval_float(self, x):
libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[0]) libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
def eval_string(self, x): def eval_string(self, x):
libc.eval_string(self.model, x.encode()) # c_char_p(x.encode())) libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
@ -38,7 +38,7 @@ if __name__ == "__main__":
# print(model) # print(model)
model.eval_string("""user: what is the color of the flag of UN?""") model.eval_string("""user: what is the color of the flag of UN?""")
# model.eval_token(100) # model.eval_token(100)
x = np.random.random((10, 5120))# , dtype=np.float32) x = np.random.random((5120,10))# , dtype=np.float32)
model.eval_float(x) model.eval_float(x)
model.eval_string("""assistant:""") model.eval_string("""assistant:""")
# print(x[0,0], x[0,1],x[1,0]) # print(x[0,0], x[0,1],x[1,0])

View file

@ -0,0 +1,46 @@
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
vision_tower = "openai/clip-vit-large-patch14"
class Llava:
def __init__(self):
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
self.mm_projector = nn.Linear(1024, 5120)
self.model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
def chat_with_image(self, image, question):
with torch.no_grad():
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
image_feature = select_hidden_state[:, 1:]
embd_image = self.mm_projector(image_feature)
embd_image = embd_image.cpu().numpy()
self.model.eval_string("user: ")
# print(embd_image.shape)
self.model.eval_float(embd_image.T)
self.model.eval_string(question)
self.model.eval_string("\nassistant: ")
ret = ""
for _ in range(500):
tmp = self.model.sampling().decode()
if tmp == "":
break
ret += tmp
return ret
a = Llava()
state = torch.load(os.path.dirname(__file__) + "/a.pth")
a.mm_projector.load_state_dict({"weight": state["model.mm_projector.weight"], "bias": state["model.mm_projector.bias"]})
print(a.chat_with_image(Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?"))