From 6089b0a50a1e734b7ccb57b8b187fcac0f34e0aa Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 1 Oct 2024 22:53:30 +0200 Subject: [PATCH] simple example works --- common/vision.cpp | 3 ++- convert_hf_to_gguf.py | 31 ++++++++++++++++++----- examples/simple/simple.cpp | 52 ++++++++++++++++++++++++++++++-------- include/llama.h | 2 ++ src/llama-vision.cpp | 38 ++++++++++++++++++++-------- src/llama-vision.h | 6 ++++- src/llama.cpp | 12 ++++++--- 7 files changed, 111 insertions(+), 33 deletions(-) diff --git a/common/vision.cpp b/common/vision.cpp index 7b5c1d995..5b003654a 100644 --- a/common/vision.cpp +++ b/common/vision.cpp @@ -32,6 +32,7 @@ llama_img * load_image_from_file(const char * fname) { // } // printf("\n"); llama_img * result = llama_img_alloc(nx, ny); - memcpy(result->data, bytes, nx*ny*nc); + memcpy(result->data, img, nx*ny*3); + stbi_image_free(img); return result; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0340c138a..1c8f912a9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11,6 +11,7 @@ import json import os import re import sys +from transformers import AutoConfig from enum import IntEnum from pathlib import Path from hashlib import sha256 @@ -67,6 +68,7 @@ class Model: is_lora: bool # for vision model + preprocessor_config: dict[str, Any] | None = None vparams: dict[str, Any] | None = None v_tensor_map: gguf.TensorNameMap v_tensor_names: set[str] | None @@ -100,6 +102,7 @@ class Model: self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -463,8 +466,20 @@ class Model: with open(dir_model / "config.json", "r", encoding="utf-8") as f: hparams = json.load(f) if "text_config" in hparams: - hparams = {**hparams, **hparams["text_config"]} + text_config = hparams["text_config"] + if "_name_or_path" in text_config: + text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict() + hparams = {**text_config, **hparams} return hparams + + @staticmethod + def load_preprocessor_config(dir_model: Path): + file_path = dir_model / "preprocessor_config.json" + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + else: + return None @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -1574,7 +1589,7 @@ class LlamaModel(Model): self.gguf_writer.add_add_bos_token(False) # For vision model - if self.vparams is not None: + if self.vparams is not None and self.preprocessor_config is not None: self.gguf_writer.add_vision_type("clip") self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) @@ -1583,14 +1598,13 @@ class LlamaModel(Model): self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) + self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"]) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 
+ self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd) # TODO: should not hardcode these, but they are currently missing from config.json self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) - self.gguf_writer.add_vision_clip_max_position_embeddings(577) self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) - default_image_mean = [0.48145466, 0.4578275, 0.40821073] - default_image_std = [0.26862954, 0.26130258, 0.27577711] - self.gguf_writer.add_vision_clip_image_mean(default_image_mean) - self.gguf_writer.add_vision_clip_image_std(default_image_std) @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): @@ -1606,8 +1620,11 @@ class LlamaModel(Model): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + # For vision model if name.startswith("language_model"): name = name.replace("language_model.", "") + if "post_layernorm" in name: + return [] # skip post_layernorm if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 303d31e05..50f2ff4ea 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -66,12 +66,41 @@ int main(int argc, char ** argv) { // TODO: this is for testing; DELETE ME - llama_img_batch ibatch; - ibatch.n_imgs = 1; - ibatch.imgs = (llama_img **) malloc(1024); - ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png"); - llama_vision_encode(ctx, &ibatch); - return 0; + int n_cur = 0; + params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:"; + { + llama_img_batch ibatch; + ibatch.n_imgs = 1; + ibatch.imgs = (llama_img **) malloc(1024); + ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg"); + llama_vision_encode(ctx, &ibatch); + + auto tokens = ::llama_tokenize(ctx, params.prompt, true); + int n_imgs = ibatch.n_imgs; + int n_embd = llama_n_embd(model); + int n_patches = llama_vision_n_patches(ctx); + printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches); + float * output_img = llama_vision_get_embeddings(ctx, 0); + + n_cur += tokens.size(); + llama_batch batch = llama_batch_init(512, 0, 1); + llama_batch_clear(batch); + for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; } + if (llama_decode(ctx, batch) != 0) { + LOG("%s: llama_decode() failed\n", __func__); + return 1; + } + + // for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]); + llama_batch_clear(batch); + batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, }; + if (llama_decode(ctx, batch) != 0) { + LOG("%s: llama_decode() failed\n", __func__); + return 1; + } + n_cur += n_embd*n_imgs; + } + params.prompt = "\nwhat did you see?\nASSISTANT:"; @@ -108,7 +137,10 @@ int main(int argc, char ** argv) { // evaluate the initial prompt for (size_t i = 0; i < tokens_list.size(); i++) { - llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + //llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + if (i == 0) continue; + llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false); + n_cur++; } // llama_decode will output logits only for the last token of the prompt @@ -121,18 +153,18 @@ int main(int argc, char ** argv) { // main loop - int n_cur = batch.n_tokens; + //int n_cur = batch.n_tokens; 
int n_decode = 0; const auto t_main_start = ggml_time_us(); - while (n_cur <= n_predict) { + for (int i = 0; i < n_predict; i++) { // sample the next token { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1); // is it an end of generation? - if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { + if (llama_token_is_eog(model, new_token_id)) { LOG("\n"); break; diff --git a/include/llama.h b/include/llama.h index 49a32d66b..ed43796b1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -903,6 +903,8 @@ extern "C" { // get output embeddings, to be put into language batch LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx); + LLAMA_API int32_t llama_vision_n_patches(struct llama_context * ctx); + // // Vocab // diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index ff6dea4f4..dab3b999a 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ -14,7 +14,8 @@ // export clip_image_u8 to bmp file for debugging // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c -static int bmp_export(const clip_image_u8 &img, const std::string &location); +struct clip_image_size; +static int bmp_export(const struct clip_image_u8 &img, const std::string &location); #endif struct clip_image_size { @@ -53,13 +54,13 @@ struct clip_image_f32 { using clip_image_f32_batch = std::vector; using clip_image_f8_batch = std::vector; -static int clip_n_patches(const clip_context & ctx) { +int clip_n_patches(const clip_context & ctx) { auto & hparams = ctx.model->hparams; int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); return n_patches; } -static int clip_n_mmproj_embd(const clip_context & ctx) { +int clip_n_mmproj_embd(const clip_context & ctx) { if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { return ctx.model->mm_b_b->ne[0]; } else { @@ -67,7 +68,7 @@ static int clip_n_mmproj_embd(const clip_context & ctx) { } } -static int clip_n_embd(const clip_context & ctx) { +int clip_n_embd(const clip_context & ctx) { return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx); } @@ -323,7 +324,7 @@ static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 const int nx = temp.nx; const int ny = temp.ny; - // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); + // bmp_export(temp, "resized_vanilla.bmp"); const int nx2 = params.image_size; const int ny2 = params.image_size; @@ -451,11 +452,11 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "pre_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_w); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b); } // loop over layers - for (int il = 0; il < (int)hparams.n_layer - 1; il++) { + for (int il = 0; il < (int)hparams.n_layer - 2; il++) { struct ggml_tensor * cur = embeddings; // layernorm1 @@ -537,6 +538,14 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, embeddings = cur; } + // post-layernorm + if (model.post_norm_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b); + } + // llava projector { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -673,6 +682,7 @@ static int32_t 
encode_image_with_clip(clip_context & ctx, const llama_img img, s clip_image_u8 img_u8(img); clip_image_f32_batch img_res_v; auto & hparams = ctx.model->hparams; + // bmp_export(img_u8, "test_inp.bmp"); if (!clip_image_preprocess(ctx, img_u8, img_res_v)) { LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__); @@ -724,7 +734,6 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch // copy output embeddings to result for (int k = 0; k < n_embd; k++) { ctx.output[n_embd*i + k] = output_single[k]; - // if (k<10) printf("%f\n", output_single[k]); } } @@ -735,10 +744,19 @@ int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch // for debugging #ifndef NDEBUG -static int bmp_export(const clip_image_u8 &img, const std::string &location) { +static int bmp_export(const struct clip_image_u8 &img, const std::string &location) { const uint32_t width = img.nx; const uint32_t height = img.ny; - const std::vector &buffer = img.buf; + // swap red and blue channel + std::vector buffer(width*height*3); + for (uint32_t y = 0; y < height; y++) { + for (uint32_t x = 0; x < width; x++) { + size_t base = x*3 + y*3*width; + buffer[base+2] = img.buf[base]; + buffer[base+1] = img.buf[base+1]; + buffer[base] = img.buf[base+2]; + } + } const bool hasAlphaChannel = false; std::ofstream fout(location, std::ios::out | std::ios::binary); diff --git a/src/llama-vision.h b/src/llama-vision.h index c14c880c4..dfcab10a5 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -6,8 +6,8 @@ #include enum vision_arch { - VISION_ARCH_LLAVA, VISION_ARCH_UNKNOWN, + VISION_ARCH_LLAVA, }; enum clip_projector_type { @@ -112,4 +112,8 @@ struct clip_context { std::vector output; // size == n_output * n_embd }; +int clip_n_patches(const clip_context & ctx); +int clip_n_mmproj_embd(const clip_context & ctx); +int clip_n_embd(const clip_context & ctx); + int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch); diff --git a/src/llama.cpp b/src/llama.cpp index b9d64764f..08b1aa17e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6239,7 +6239,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true); ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true); ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true); - ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true); + ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true); if (proj_type == "mlp") { vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP; } else { @@ -8987,9 +8987,9 @@ static bool llm_load_tensors( model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd}); model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd}); - model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); - model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}); - model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}); + model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); + // model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}); + // model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * 
ctx_layer = ctx_for_layer(i); @@ -21815,6 +21815,10 @@ float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) { return ctx->clip.output.data(); } +int32_t llama_vision_n_patches(struct llama_context * ctx) { + return clip_n_patches(ctx->clip); +} + // // model split //
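
For reference, the image-prefill flow added to examples/simple/simple.cpp reads as follows when lifted out of the diff context. This is a sketch, not code from the patch: it only rearranges calls that already appear above (load_image_from_file, llama_vision_encode, llama_vision_get_embeddings, llama_vision_n_patches, llama_decode with an embeddings-only llama_batch). The image path and the "vision.h" header name are placeholders, and advancing the position counter by n_patches is my assumption; the patched example itself advances it by n_embd*n_imgs.

    // sketch of the image-prefill step introduced in examples/simple/simple.cpp
    #include <cstdlib>
    #include "llama.h"
    #include "vision.h"   // header for the common/vision.cpp helpers (name assumed)

    // encode one image and feed its projected embeddings to the language model;
    // n_cur tracks the next KV-cache position, as in the patched example
    static bool eval_image(llama_context * ctx, int & n_cur) {
        llama_img_batch ibatch;
        ibatch.n_imgs  = 1;
        ibatch.imgs    = (llama_img **) malloc(1024);            // allocation copied from the patched example
        ibatch.imgs[0] = load_image_from_file("image.jpg");      // placeholder path
        llama_vision_encode(ctx, &ibatch);                       // run the CLIP encoder + projector

        const int n_patches = llama_vision_n_patches(ctx);       // positions the image will occupy (assumption)
        float   * embd      = llama_vision_get_embeddings(ctx, 0);

        // embeddings-only batch: token == nullptr, embd != nullptr;
        // field order matches the aggregate initializer used in the patch
        llama_batch batch = { n_patches, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
        if (llama_decode(ctx, batch) != 0) {
            return false;
        }
        n_cur += n_patches;   // assumption: one position per patch embedding
        return true;
    }

After this call the caller continues exactly as the patched example does: tokenize the remaining text prompt, add the tokens starting at n_cur, and decode as usual.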