simple example works

Xuan Son Nguyen 2024-10-01 22:53:30 +02:00
parent 4897ff61c6
commit 6089b0a50a
7 changed files with 111 additions and 33 deletions


@@ -32,6 +32,7 @@ llama_img * load_image_from_file(const char * fname) {
     //     }
     //     printf("\n");
     llama_img * result = llama_img_alloc(nx, ny);
-    memcpy(result->data, bytes, nx*ny*nc);
+    memcpy(result->data, img, nx*ny*3);
+    stbi_image_free(img);
     return result;
 }
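For context, the corrected loader copies the decoded RGB pixels (nx*ny*3 bytes, since stb_image is asked for 3 channels) instead of the raw file bytes, and releases stb_image's buffer once the copy is done. A minimal standalone sketch of that flow, assuming the branch's WIP llama_img/llama_img_alloc API exposes a writable RGB buffer in result->data:

    // sketch only: llama_img / llama_img_alloc are the WIP vision API of this branch
    #include <cstring>
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"

    static llama_img * load_rgb_image(const char * fname) {
        int nx, ny, nc;
        unsigned char * img = stbi_load(fname, &nx, &ny, &nc, 3); // force 3 channels (RGB)
        if (img == nullptr) {
            return nullptr;                          // file missing or unsupported format
        }
        llama_img * result = llama_img_alloc(nx, ny);
        memcpy(result->data, img, (size_t) nx*ny*3); // copy decoded pixels, not file bytes
        stbi_image_free(img);                        // stb_image's buffer is no longer needed
        return result;
    }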


@@ -11,6 +11,7 @@ import json
 import os
 import re
 import sys
+from transformers import AutoConfig
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
@@ -67,6 +68,7 @@ class Model:
     is_lora: bool

     # for vision model
+    preprocessor_config: dict[str, Any] | None = None
     vparams: dict[str, Any] | None = None
     v_tensor_map: gguf.TensorNameMap
     v_tensor_names: set[str] | None
@@ -100,6 +102,7 @@ class Model:
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
         self.is_lora = is_lora  # true if model is used inside convert_lora_to_gguf.py
+        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -463,8 +466,20 @@ class Model:
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
             if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
+                text_config = hparams["text_config"]
+                if "_name_or_path" in text_config:
+                    text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
+                hparams = {**text_config, **hparams}
             return hparams

+    @staticmethod
+    def load_preprocessor_config(dir_model: Path):
+        file_path = dir_model / "preprocessor_config.json"
+        if os.path.exists(file_path):
+            with open(file_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        else:
+            return None
+
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -1574,7 +1589,7 @@ class LlamaModel(Model):
         self.gguf_writer.add_add_bos_token(False)

         # For vision model
-        if self.vparams is not None:
+        if self.vparams is not None and self.preprocessor_config is not None:
             self.gguf_writer.add_vision_type("clip")
             self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
             self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
@@ -1583,14 +1598,13 @@ class LlamaModel(Model):
             self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
             self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
             self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
+            self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
+            self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
+            max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
+            self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
             # TODO: should not hardcode these, but they are currently missing from config.json
             self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
-            self.gguf_writer.add_vision_clip_max_position_embeddings(577)
             self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
-            default_image_mean = [0.48145466, 0.4578275, 0.40821073]
-            default_image_std = [0.26862954, 0.26130258, 0.27577711]
-            self.gguf_writer.add_vision_clip_image_mean(default_image_mean)
-            self.gguf_writer.add_vision_clip_image_std(default_image_std)

     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
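As a sanity check on the new formula: for the CLIP ViT-L/14-336 tower that LLaVA uses (image_size = 336, patch_size = 14), (336 // 14)**2 + 1 = 576 + 1 = 577, which is exactly the value that was hardcoded before, so existing LLaVA conversions keep the same metadata.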
@@ -1606,8 +1620,11 @@ class LlamaModel(Model):
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")

+        # For vision model
         if name.startswith("language_model"):
             name = name.replace("language_model.", "")
+        if "post_layernorm" in name:
+            return [] # skip post_layernorm

         if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
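Dropping post_layernorm in the converter appears consistent with the C++ changes further down: the vision graph now stops two layers early (LLaVA feeds its projector from a lower encoder layer) and the post-norm tensors are commented out in llm_load_tensors, so exporting the tensor would only trip the converter's unmapped-tensor check without ever being read.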


@@ -66,12 +66,41 @@ int main(int argc, char ** argv) {
     // TODO: this is for testing; DELETE ME
-    llama_img_batch ibatch;
-    ibatch.n_imgs = 1;
-    ibatch.imgs = (llama_img **) malloc(1024);
-    ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
-    llama_vision_encode(ctx, &ibatch);
-    return 0;
+    int n_cur = 0;
+    params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
+    {
+        llama_img_batch ibatch;
+        ibatch.n_imgs = 1;
+        ibatch.imgs = (llama_img **) malloc(1024);
+        ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");
+        llama_vision_encode(ctx, &ibatch);
+
+        auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+        int n_imgs    = ibatch.n_imgs;
+        int n_embd    = llama_n_embd(model);
+        int n_patches = llama_vision_n_patches(ctx);
+        printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches);
+        float * output_img = llama_vision_get_embeddings(ctx, 0);
+
+        n_cur += tokens.size();
+        llama_batch batch = llama_batch_init(512, 0, 1);
+        llama_batch_clear(batch);
+        for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        // for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]);
+
+        llama_batch_clear(batch);
+        batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
+        if (llama_decode(ctx, batch) != 0) {
+            LOG("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+        n_cur += n_embd*n_imgs;
+    }
+    params.prompt = "\nwhat did you see?\nASSISTANT:";
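The brace-initialized batch above leans on the legacy llama_batch layout of this era (n_tokens, token, embd, pos, n_seq_id, seq_id, logits, all_pos_0, all_pos_1, all_seq_id), where a null pos/seq_id means row i gets position all_pos_0 + i*all_pos_1 in sequence all_seq_id. A more explicit sketch of the same idea, feeding the CLIP output into the language model as an embedding-only batch (illustrative only, using the WIP vision API of this branch; note it advances n_cur by the number of embedding rows, n_patches*n_imgs):

    // decode the image embeddings returned by llama_vision_get_embeddings()
    llama_batch img_batch = {
        /*.n_tokens   =*/ int32_t(n_patches * n_imgs),
        /*.token      =*/ nullptr,     // no token ids: rows come from .embd instead
        /*.embd       =*/ output_img,  // n_patches*n_imgs rows of n_embd floats
        /*.pos        =*/ nullptr,     // fall back to the all_pos_* fields below
        /*.n_seq_id   =*/ nullptr,
        /*.seq_id     =*/ nullptr,
        /*.logits     =*/ nullptr,     // no logits needed for the image part
        /*.all_pos_0  =*/ n_cur,       // first position right after the text prefix
        /*.all_pos_1  =*/ 1,           // consecutive positions for consecutive rows
        /*.all_seq_id =*/ 0,           // single sequence
    };
    if (llama_decode(ctx, img_batch) != 0) {
        fprintf(stderr, "%s: llama_decode() failed on image batch\n", __func__);
        return 1;
    }
    n_cur += n_patches * n_imgs;       // the prompt continues after the image rows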
@@ -108,7 +137,10 @@ int main(int argc, char ** argv) {
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        //llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        if (i == 0) continue;
+        llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false);
+        n_cur++;
     }

     // llama_decode will output logits only for the last token of the prompt
@@ -121,18 +153,18 @@ int main(int argc, char ** argv) {
     // main loop

-    int n_cur    = batch.n_tokens;
+    //int n_cur    = batch.n_tokens;
     int n_decode = 0;

     const auto t_main_start = ggml_time_us();

-    while (n_cur <= n_predict) {
+    for (int i = 0; i < n_predict; i++) {
         // sample the next token
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);

             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
+            if (llama_token_is_eog(model, new_token_id)) {
                 LOG("\n");

                 break;


@@ -903,6 +903,8 @@ extern "C" {
     // get output embeddings, to be put into language batch
     LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx);
+    LLAMA_API int32_t llama_vision_n_patches(struct llama_context * ctx);
+
     //
     // Vocab
     //
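Together with llama_vision_get_embeddings(), the new getter lets a caller size buffers for the image part of the prompt. A small sketch, assuming the WIP API of this branch (the projector maps CLIP features to the language model's hidden size, so each patch row is n_embd floats wide):

    const int     n_patches = llama_vision_n_patches(ctx);         // rows per encoded image
    const int     n_embd    = llama_n_embd(model);                  // floats per row
    const float * img_embd  = llama_vision_get_embeddings(ctx, 0);  // embeddings of image 0
    const size_t  n_floats  = (size_t) n_patches * n_embd;          // total floats per image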


@@ -14,7 +14,8 @@
 // export clip_image_u8 to bmp file for debugging
 // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-static int bmp_export(const clip_image_u8 &img, const std::string &location);
+struct clip_image_size;
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location);
 #endif

 struct clip_image_size {
@@ -53,13 +54,13 @@ struct clip_image_f32 {
 using clip_image_f32_batch = std::vector<clip_image_f32>;
 using clip_image_f8_batch = std::vector<clip_image_u8>;

-static int clip_n_patches(const clip_context & ctx) {
+int clip_n_patches(const clip_context & ctx) {
     auto & hparams = ctx.model->hparams;
     int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
     return n_patches;
 }

-static int clip_n_mmproj_embd(const clip_context & ctx) {
+int clip_n_mmproj_embd(const clip_context & ctx) {
     if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
         return ctx.model->mm_b_b->ne[0];
     } else {
@@ -67,7 +68,7 @@ static int clip_n_mmproj_embd(const clip_context & ctx) {
     }
 }

-static int clip_n_embd(const clip_context & ctx) {
+int clip_n_embd(const clip_context & ctx) {
     return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx);
 }
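With the usual LLaVA settings (image_size = 336, patch_size = 14), clip_n_patches() returns (336/14)^2 = 24*24 = 576, one fewer than the 577 position embeddings written by the converter, since the extra position belongs to the class token rather than the patch grid.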
@@ -323,7 +324,7 @@ static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8
     const int nx = temp.nx;
     const int ny = temp.ny;

-    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
+    // bmp_export(temp, "resized_vanilla.bmp");

     const int nx2 = params.image_size;
     const int ny2 = params.image_size;
@@ -451,11 +452,11 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = ggml_norm(ctx0, embeddings, eps);
         ggml_set_name(embeddings, "pre_ln");

-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_w);
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b);
     }

     // loop over layers
-    for (int il = 0; il < (int)hparams.n_layer - 1; il++) {
+    for (int il = 0; il < (int)hparams.n_layer - 2; il++) {
         struct ggml_tensor * cur = embeddings;

         // layernorm1
@@ -537,6 +538,14 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
         embeddings = cur;
     }

+    // post-layernorm
+    if (model.post_norm_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b);
+    }
+
     // llava projector
     {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
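Both the pre_ln and the new post_ln block are the usual affine layer norm; per element, ggml_norm followed by ggml_mul/ggml_add computes roughly

    y = (x - mean(x)) / sqrt(var(x) + eps) * w + b

The if (model.post_norm_w) guard matters because the post-norm tensors are still commented out in llm_load_tensors further down, so this block stays inactive until they are loaded again.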
@@ -673,6 +682,7 @@ static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, s
     clip_image_u8 img_u8(img);
     clip_image_f32_batch img_res_v;
     auto & hparams = ctx.model->hparams;
+    // bmp_export(img_u8, "test_inp.bmp");

     if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
         LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
@@ -724,7 +734,6 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch
         // copy output embeddings to result
         for (int k = 0; k < n_embd; k++) {
             ctx.output[n_embd*i + k] = output_single[k];
-            // if (k<10) printf("%f\n", output_single[k]);
         }
     }
@@ -735,10 +744,19 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch
 // for debugging
 #ifndef NDEBUG
-static int bmp_export(const clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const struct clip_image_u8 &img, const std::string &location) {
     const uint32_t width = img.nx;
     const uint32_t height = img.ny;
-    const std::vector<uint8_t> &buffer = img.buf;
+    // swap red and blue channel
+    std::vector<uint8_t> buffer(width*height*3);
+    for (uint32_t y = 0; y < height; y++) {
+        for (uint32_t x = 0; x < width; x++) {
+            size_t base = x*3 + y*3*width;
+            buffer[base+2] = img.buf[base];
+            buffer[base+1] = img.buf[base+1];
+            buffer[base]   = img.buf[base+2];
+        }
+    }
     const bool hasAlphaChannel = false;

     std::ofstream fout(location, std::ios::out | std::ios::binary);
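The new copy loop exists because BMP stores pixels in BGR order while clip_image_u8 holds RGB: for the pixel at (x, y) the byte offset is base = 3*(y*width + x), and the red and blue bytes are swapped into a temporary buffer before the file is written, leaving the original image untouched.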


@@ -6,8 +6,8 @@
 #include <array>

 enum vision_arch {
+    VISION_ARCH_LLAVA,
     VISION_ARCH_UNKNOWN,
-    VISION_ARCH_LLAVA,
 };

 enum clip_projector_type {
@@ -112,4 +112,8 @@ struct clip_context {
     std::vector<float> output; // size == n_output * n_embd
 };

+int clip_n_patches(const clip_context & ctx);
+int clip_n_mmproj_embd(const clip_context & ctx);
+int clip_n_embd(const clip_context & ctx);
+
 int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch);


@@ -6239,7 +6239,7 @@ static void llm_load_hparams(
             ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
             ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
             ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
             ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
             if (proj_type == "mlp") {
                 vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP;
             } else {
@@ -8987,9 +8987,9 @@ static bool llm_load_tensors(
         model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd});

         model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd});
-        model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
-        model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
-        model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});
+        model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
+        // model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
+        // model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd});

         for (int i = 0; i < n_layer; ++i) {
             ggml_context * ctx_layer = ctx_for_layer(i);
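Two fixes in one hunk: the bias tensor now lands in pre_norm_b instead of silently overwriting pre_norm_w, and the post-norm tensors are disabled for now, which is what the if (model.post_norm_w) guard in the graph builder and the post_layernorm skip in the converter are accounting for.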
@@ -21815,6 +21815,10 @@ float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) {
     return ctx->clip.output.data();
 }

+int32_t llama_vision_n_patches(struct llama_context * ctx) {
+    return clip_n_patches(ctx->clip);
+}
+
 //
 // model split
 //