simple example works
parent 4897ff61c6
commit 6089b0a50a

7 changed files with 111 additions and 33 deletions
@@ -32,6 +32,7 @@ llama_img * load_image_from_file(const char * fname) {
     // }
+    // printf("\n");
     llama_img * result = llama_img_alloc(nx, ny);
-    memcpy(result->data, bytes, nx*ny*nc);
+    memcpy(result->data, img, nx*ny*3);
     stbi_image_free(img);
     return result;
 }
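For context, this is how the whole loader reads with the fix above applied: a minimal sketch that assumes stbi_load is asked for exactly 3 channels (which is what makes the nx*ny*3 copy safe) and that the stb_image and string headers are included; everything outside the hunk, including the error handling, is an assumption rather than part of this commit.

// sketch only: the stbi_load call and the error handling are assumed, not shown in the diff
llama_img * load_image_from_file(const char * fname) {
    int nx, ny, nc;
    unsigned char * img = stbi_load(fname, &nx, &ny, &nc, 3); // force 3 channels (RGB)
    if (img == nullptr) {
        return nullptr;
    }
    llama_img * result = llama_img_alloc(nx, ny);
    memcpy(result->data, img, nx*ny*3); // copy the decoded RGB pixels, 3 bytes per pixel
    stbi_image_free(img);
    return result;
}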
@@ -11,6 +11,7 @@ import json
 import os
 import re
 import sys
+from transformers import AutoConfig
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256

@@ -67,6 +68,7 @@ class Model:
     is_lora: bool

     # for vision model
+    preprocessor_config: dict[str, Any] | None = None
     vparams: dict[str, Any] | None = None
     v_tensor_map: gguf.TensorNameMap
     v_tensor_names: set[str] | None

@@ -100,6 +102,7 @@ class Model:
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
         self.is_lora = is_lora  # true if model is used inside convert_lora_to_gguf.py
+        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:

@@ -463,8 +466,20 @@ class Model:
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
             if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
+                text_config = hparams["text_config"]
+                if "_name_or_path" in text_config:
+                    text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
+                hparams = {**text_config, **hparams}
             return hparams

+    @staticmethod
+    def load_preprocessor_config(dir_model: Path):
+        file_path = dir_model / "preprocessor_config.json"
+        if os.path.exists(file_path):
+            with open(file_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        else:
+            return None
+
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:

@@ -1574,7 +1589,7 @@ class LlamaModel(Model):
         self.gguf_writer.add_add_bos_token(False)

         # For vision model
-        if self.vparams is not None:
+        if self.vparams is not None and self.preprocessor_config is not None:
             self.gguf_writer.add_vision_type("clip")
             self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
             self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])

@@ -1583,14 +1598,13 @@ class LlamaModel(Model):
             self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
             self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
             self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
+            self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
+            self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
+            max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
+            self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
             # TODO: should not hardcode these, but they are currently missing from config.json
             self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
-            self.gguf_writer.add_vision_clip_max_position_embeddings(577)
             self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
-            default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-            default_image_std = [0.26862954, 0.26130258, 0.27577711]
-            self.gguf_writer.add_vision_clip_image_mean(default_image_mean)
-            self.gguf_writer.add_vision_clip_image_std(default_image_std)

     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):

@@ -1606,8 +1620,11 @@ class LlamaModel(Model):
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")

         # For vision model
         if name.startswith("language_model"):
             name = name.replace("language_model.", "")
+            if "post_layernorm" in name:
+                return [] # skip post_layernorm
+
         if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
@@ -66,12 +66,41 @@ int main(int argc, char ** argv) {

-    // TODO: this is for testing; DELETE ME
-    llama_img_batch ibatch;
-    ibatch.n_imgs = 1;
-    ibatch.imgs = (llama_img **) malloc(1024);
-    ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
-    llama_vision_encode(ctx, &ibatch);
-    return 0;
+    int n_cur = 0;
+    params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
+    {
+        llama_img_batch ibatch;
+        ibatch.n_imgs = 1;
+        ibatch.imgs = (llama_img **) malloc(1024);
+        ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");
+        llama_vision_encode(ctx, &ibatch);
+
+        auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+        int n_imgs = ibatch.n_imgs;
+        int n_embd = llama_n_embd(model);
+        int n_patches = llama_vision_n_patches(ctx);
+        printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches);
+        float * output_img = llama_vision_get_embeddings(ctx, 0);
+
+        n_cur += tokens.size();
+        llama_batch batch = llama_batch_init(512, 0, 1);
+        llama_batch_clear(batch);
+        for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        // for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]);
+        llama_batch_clear(batch);
+        batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+        n_cur += n_embd*n_imgs;
+    }
+    params.prompt = "\nwhat did you see?\nASSISTANT:";

@@ -108,7 +137,10 @@ int main(int argc, char ** argv) {

     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        //llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        if (i == 0) continue;
+        llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false);
+        n_cur++;
     }

     // llama_decode will output logits only for the last token of the prompt

@@ -121,18 +153,18 @@ int main(int argc, char ** argv) {

     // main loop

-    int n_cur = batch.n_tokens;
+    //int n_cur = batch.n_tokens;
     int n_decode = 0;

     const auto t_main_start = ggml_time_us();

-    while (n_cur <= n_predict) {
+    for (int i = 0; i < n_predict; i++) {
         // sample the next token
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);

             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
+            if (llama_token_is_eog(model, new_token_id)) {
                 LOG("\n");

                 break;
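The step that makes the example work is the second llama_decode call above: it feeds the CLIP output to the model through the embd pointer of a hand-built llama_batch instead of token ids. Below is a trimmed sketch of just that step; the field names in the comments describe the llama_batch layout this initializer appears to target and are an assumption, not something spelled out in the diff.

// assumes ctx, output_img (= llama_vision_get_embeddings(ctx, 0)),
// n_patches, n_imgs and n_cur are set up exactly as in the example above
llama_batch img_batch = {
    /* n_tokens   */ int32_t(n_patches*n_imgs),  // one "token" per image patch
    /* token      */ nullptr,                    // no token ids ...
    /* embd       */ output_img,                 // ... the input is raw embeddings
    /* pos        */ nullptr,
    /* n_seq_id   */ nullptr,
    /* seq_id     */ nullptr,
    /* logits     */ nullptr,
    /* all_pos_0  */ n_cur,                      // patches start right after the prompt tokens
    /* all_pos_1  */ 1,
    /* all_seq_id */ 0,
};
if (llama_decode(ctx, img_batch) != 0) {
    LOG("%s: llama_decode() failed\n", __func__);
    return 1;
}

With pos left null, the legacy all_pos_0/all_pos_1 fields assign consecutive positions to the patch embeddings, so no explicit position array is needed here.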
@ -903,6 +903,8 @@ extern "C" {
|
|||
// get output embeddings, to be put into language batch
|
||||
LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx);
|
||||
|
||||
LLAMA_API int32_t llama_vision_n_patches(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
|
|
|
@@ -14,7 +14,8 @@

 // export clip_image_u8 to bmp file for debugging
 // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-static int bmp_export(const clip_image_u8 &img, const std::string &location);
+struct clip_image_size;
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location);
 #endif

 struct clip_image_size {

@@ -53,13 +54,13 @@ struct clip_image_f32 {
 using clip_image_f32_batch = std::vector<clip_image_f32>;
 using clip_image_f8_batch = std::vector<clip_image_u8>;

-static int clip_n_patches(const clip_context & ctx) {
+int clip_n_patches(const clip_context & ctx) {
     auto & hparams = ctx.model->hparams;
     int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
     return n_patches;
 }

-static int clip_n_mmproj_embd(const clip_context & ctx) {
+int clip_n_mmproj_embd(const clip_context & ctx) {
     if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
         return ctx.model->mm_b_b->ne[0];
     } else {

@@ -67,7 +68,7 @@ static int clip_n_mmproj_embd(const clip_context & ctx) {
     }
 }

-static int clip_n_embd(const clip_context & ctx) {
+int clip_n_embd(const clip_context & ctx) {
     return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx);
 }
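As a quick sanity check on these helpers, the numbers below use the 336-pixel, 14-pixel-patch CLIP encoder that LLaVA-style models typically ship with; the concrete dimensions are an assumption for illustration, not values read from this diff.

// worked example with assumed LLaVA-style CLIP dimensions (336 px image, 14 px patches)
const int image_size = 336;
const int patch_size = 14;
const int n_patches  = (image_size / patch_size) * (image_size / patch_size); // 24 * 24 = 576, what clip_n_patches() returns
const int max_pos    = n_patches + 1;  // 577: one extra class-token position, the same
                                       // (image_size // patch_size)**2 + 1 now computed in convert_hf_to_gguf.py
// clip_n_embd() is then n_patches * clip_n_mmproj_embd(ctx): the flattened size of the
// buffer that llama_vision_get_embeddings() exposes to the language model.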
@@ -323,7 +324,7 @@ static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8

     const int nx = temp.nx;
     const int ny = temp.ny;
-    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
+    // bmp_export(temp, "resized_vanilla.bmp");

     const int nx2 = params.image_size;
     const int ny2 = params.image_size;

@@ -451,11 +452,11 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = ggml_norm(ctx0, embeddings, eps);
         ggml_set_name(embeddings, "pre_ln");

-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_w);
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b);
     }

     // loop over layers
-    for (int il = 0; il < (int)hparams.n_layer - 1; il++) {
+    for (int il = 0; il < (int)hparams.n_layer - 2; il++) {
         struct ggml_tensor * cur = embeddings;

         // layernorm1

@@ -537,6 +538,14 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = cur;
     }

+    // post-layernorm
+    if (model.post_norm_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b);
+    }
+
     // llava projector
     {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

@@ -673,6 +682,7 @@ static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, s
     clip_image_u8 img_u8(img);
     clip_image_f32_batch img_res_v;
     auto & hparams = ctx.model->hparams;
+    // bmp_export(img_u8, "test_inp.bmp");

     if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
         LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);

@@ -724,7 +734,6 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch
         // copy output embeddings to result
         for (int k = 0; k < n_embd; k++) {
             ctx.output[n_embd*i + k] = output_single[k];
-            // if (k<10) printf("%f\n", output_single[k]);
         }
     }

@@ -735,10 +744,19 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch
 // for debugging
 #ifndef NDEBUG

-static int bmp_export(const clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location) {
     const uint32_t width = img.nx;
     const uint32_t height = img.ny;
-    const std::vector<uint8_t> &buffer = img.buf;
+    // swap red and blue channel
+    std::vector<uint8_t> buffer(width*height*3);
+    for (uint32_t y = 0; y < height; y++) {
+        for (uint32_t x = 0; x < width; x++) {
+            size_t base = x*3 + y*3*width;
+            buffer[base+2] = img.buf[base];
+            buffer[base+1] = img.buf[base+1];
+            buffer[base] = img.buf[base+2];
+        }
+    }
     const bool hasAlphaChannel = false;

     std::ofstream fout(location, std::ios::out | std::ios::binary);
@@ -6,8 +6,8 @@
 #include <array>

 enum vision_arch {
-    VISION_ARCH_LLAVA,
     VISION_ARCH_UNKNOWN,
+    VISION_ARCH_LLAVA,
 };

 enum clip_projector_type {

@@ -112,4 +112,8 @@ struct clip_context {
     std::vector<float> output; // size == n_output * n_embd
 };

+int clip_n_patches(const clip_context & ctx);
+int clip_n_mmproj_embd(const clip_context & ctx);
+int clip_n_embd(const clip_context & ctx);
+
 int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch);
@@ -6239,7 +6239,7 @@ static void llm_load_hparams(
             ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
             ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
             ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
-            ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
+            ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
             if (proj_type == "mlp") {
                 vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP;
             } else {

@@ -8987,9 +8987,9 @@ static bool llm_load_tensors(
             model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd});

             model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd});
-            model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
-            model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
-            model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});
+            model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
+            // model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
+            // model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});

             for (int i = 0; i < n_layer; ++i) {
                 ggml_context * ctx_layer = ctx_for_layer(i);

@@ -21815,6 +21815,10 @@ float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) {
     return ctx->clip.output.data();
 }

+int32_t llama_vision_n_patches(struct llama_context * ctx) {
+    return clip_n_patches(ctx->clip);
+}
+
 //
 // model split
 //