From c13edfed59928294d3cd2f6bdf56bbdee0ba3997 Mon Sep 17 00:00:00 2001 From: HimariO Date: Thu, 10 Oct 2024 22:30:42 +0800 Subject: [PATCH] [WIP] qwen2vl vision model --- examples/llava/clip.cpp | 221 ++++++++++++++++++++++--- examples/llava/clip.h | 6 +- examples/llava/qwen2_vl_surgery.py | 254 ++++++++++++++--------------- examples/llava/qwen2vl-cli.cpp | 29 +++- gguf-py/gguf/constants.py | 16 ++ 5 files changed, 365 insertions(+), 161 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 3561c6951..fe744a6c1 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -98,7 +98,9 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" +#define KEY_USE_SILU "clip.use_silu" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" #define KEY_N_BLOCK "clip.%s.block_count" @@ -125,7 +127,8 @@ static std::string format(const char * fmt, ...) { #define TN_TOKEN_EMBD "%s.token_embd.weight" #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" -#define TN_PATCH_EMBD "v.patch_embd.weight" +#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat +#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" #define TN_PATCH_BIAS "v.patch_embd.bias" #define TN_ATTN_K "%s.blk.%d.attn_k.%s" #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" @@ -159,6 +162,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_UNKNOWN, }; @@ -167,6 +171,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, }; @@ -459,8 +464,8 @@ struct clip_vision_model { // embeddings struct ggml_tensor * class_embedding; - struct ggml_tensor * patch_embeddings; - struct ggml_tensor * patch_embeddings_t1; // second kernel of Conv3D when we decouple along temproal dimension + struct ggml_tensor * patch_embeddings_0; + struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) struct ggml_tensor * patch_bias; struct ggml_tensor * position_embeddings; @@ -550,6 +555,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_qwen2vl_merger = false; int minicpmv_version = 2; struct clip_vision_model vision_model; @@ -558,6 +564,7 @@ struct clip_ctx { float image_mean[3]; float image_std[3]; bool use_gelu = false; + bool use_silu = false; int32_t ftype = 1; bool has_class_embedding = true; @@ -603,14 +610,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 image_size_height = imgs->data->ny; } } + else if (ctx->has_qwen2vl_merger) { + // use the image's native resolution when image is avaible + if (is_inf) { + // if (imgs->data->nx && imgs->data->ny) { + image_size_width = imgs->data->nx; + image_size_height = imgs->data->ny; + } + } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int patches_w = image_size_width / patch_size; + const int patches_h = image_size_height / 
patch_size; const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 3 : num_positions; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; int n_layer = hparams.n_layer; const float eps = hparams.eps; + int mrope_sections[3] = {d_head/4, d_head/4, 0}; const int batch_size = imgs->size; @@ -631,10 +650,34 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + if (ctx->has_qwen2vl_merger) { + GGML_ASSERT(image_size_width % (patch_size * 2) == 0); + GGML_ASSERT(image_size_height % (patch_size * 2) == 0); + + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, patches_h, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); + // inp = ggml_reshape_2d( + // ctx0, inp, + // hidden_size * 4, (patches_w / 2) * batch_size * (patches_h / 2)); + inp = ggml_reshape_3d( + ctx0, inp, + hidden_size, patches_w * patches_h, batch_size); + + } + else { + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + } if (ctx->has_patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); @@ -656,12 +699,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); ggml_set_name(positions, "positions"); ggml_set_input(positions); - embeddings = - ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding + embeddings = + ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + } if (ctx->has_minicpmv_projector) { int pos_w = image_size_width/patch_size; @@ -707,8 +752,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + if (ctx->has_qwen2vl_merger) { + Q = ggml_mrope_ext( + ctx0, Q, positions, nullptr, + d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX8*/, 32768, 1000000, 1, 0, 1, 32, 1); + } + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); @@ -716,6 +766,11 @@ 
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + if (ctx->has_qwen2vl_merger) { + K = ggml_mrope_ext( + ctx0, K, positions, nullptr, + d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX8*/, 32768, 1000000, 1, 0, 1, 32, 1); + } K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); @@ -755,6 +810,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (ctx->use_gelu) { cur = ggml_gelu_inplace(ctx0, cur); + } else if (ctx->use_silu) { + cur = ggml_silu_inplace(ctx0, cur); } else { cur = ggml_gelu_quick_inplace(ctx0, cur); } @@ -1027,6 +1084,29 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } + else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + // // First LayerNorm + // embeddings = ggml_norm(ctx0, embeddings, eps); + // embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w), + // model.mm_1_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + // // Second LayerNorm + // embeddings = ggml_norm(ctx0, embeddings, eps); + // embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), + // model.mm_4_b); + } // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1198,6 +1278,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); + if (idx != -1) { + new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); + } // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search GGML_ASSERT(new_clip->has_vision_encoder); @@ -1205,6 +1289,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { idx = get_key_idx(ctx, KEY_USE_GELU); new_clip->use_gelu = gguf_get_val_bool(ctx, idx); + + idx = get_key_idx(ctx, KEY_USE_SILU); + new_clip->use_silu = gguf_get_val_bool(ctx, idx); if (verbosity >= 1) { LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); @@ -1381,11 +1468,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } try { - vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); + vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); } catch(const std::exception& /*e*/) { LOG_ERR("%s: failed to load vision model tensors\n", __func__); } + try { + vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1); + } catch(const std::exception& /*e*/) { + new_clip->has_qwen2vl_merger = false; + } // LLaVA projection if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { @@ -1473,6 +1565,12 @@ struct clip_ctx * clip_model_load(const char * fname, const 
int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { + vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); + vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); + vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); + vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); + } else { std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -1511,6 +1609,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend)); clip_image_f32_batch batch; batch.size = 1; + batch.data = nullptr; ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); @@ -1975,6 +2074,22 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } return true; + } + else if (ctx->has_qwen2vl_merger) { + clip_image_u8 * resized = clip_image_u8_init(); + auto patch_size = clip_patch_size(ctx) * 2; + int nx = ceil((float)img->nx / patch_size) * patch_size; + int ny = ceil((float)img->ny / patch_size) * patch_size; + bicubic_resize(*img, *resized, nx, ny); + + res_imgs->data = new clip_image_f32[1]; + // clip_image_f32 * res = clip_image_f32_init(); + normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std); + // res_imgs->data[0] = *res; + + // clip_image_f32_free(res); + clip_image_u8_free(resized); + return true; } bool pad_to_square = true; @@ -2186,6 +2301,13 @@ const int32_t * clip_image_grid(const struct clip_ctx * ctx) { } int clip_n_patches(const struct clip_ctx * ctx) { + clip_image_f32 img; + img.nx = ctx->vision_model.hparams.image_size; + img.ny = ctx->vision_model.hparams.image_size; + return clip_n_patches_by_img(ctx, &img); +} + +int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->vision_model.hparams; int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); @@ -2199,6 +2321,11 @@ int clip_n_patches(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { n_patches = 64; } + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + int patch_size = params.patch_size * 2; + int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); + int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); + n_patches = x_patch * y_patch; } return n_patches; @@ -2327,13 +2454,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int image_size = hparams.image_size; int image_size_width = image_size; int image_size_height = image_size; - if (ctx->has_minicpmv_projector) { + if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { image_size_width = imgs->data[0].nx; image_size_height = imgs->data[0].ny; } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * 
(image_size_height / patch_size));
     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 3 : num_positions;
     if(ctx->load_image_size==nullptr){
         ctx->load_image_size= clip_image_size_init();
     }
@@ -2347,7 +2475,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     for (size_t i = 0; i < imgs->size; i++) {
         const int nx = imgs->data[i].nx;
         const int ny = imgs->data[i].ny;
-        if (!ctx->has_minicpmv_projector) {
+        if (!(ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger)) {
             GGML_ASSERT(nx == image_size && ny == image_size);
         }
@@ -2427,7 +2555,41 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }
-    {
+    if (ctx->has_qwen2vl_merger) {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+        const int pw = image_size_width / patch_size;
+        const int ph = image_size_height / patch_size;
+        int* positions_data = (int*)malloc(ggml_nbytes(positions));
+        // for (size_t y = 0; y < ph; y++)
+        // {
+        //     for (size_t x = 0; x < pw; x++)
+        //     {
+        //         positions_data[y * pw + x] = y;
+        //         positions_data[num_patches + (y * pw + x)] = x;
+        //         positions_data[num_patches * 2 + (y * pw + x)] = 0;
+        //     }
+        // }
+
+        // walk the patch grid in 2x2 blocks (the merge order) and fill the y/x/zero position planes
+        int ptr = 0;
+        for (int y = 0; y < ph; y += 2)
+        {
+            for (int x = 0; x < pw; x += 2)
+            {
+                for (int dy = 0; dy < 2; dy++) {
+                    for (int dx = 0; dx < 2; dx++) {
+                        positions_data[ptr]                     = y + dy;
+                        positions_data[num_patches + ptr]       = x + dx;
+                        positions_data[num_patches * 2 + ptr++] = 0;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
+    }
+    else {
         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
         int* positions_data = (int*)malloc(ggml_nbytes(positions));
@@ -2436,16 +2598,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
         ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
         free(positions_data);
-    }
-    {
-        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
-        int* patches_data = (int*)malloc(ggml_nbytes(patches));
-        for (int i = 0; i < num_patches; i++) {
-            patches_data[i] = i + 1;
+        {
+            struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+            int* patches_data = (int*)malloc(ggml_nbytes(patches));
+            for (int i = 0; i < num_patches; i++) {
+                patches_data[i] = i + 1;
+            }
+            ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+            free(patches_data);
         }
-        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
-        free(patches_data);
     }
 }
@@ -2629,3 +2791,18 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
     }
     return 0;
 }
+
+
+bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
+    clip_image_f32 clip_img;
+    clip_img.buf.resize(h * w * 3);
+    for (size_t i = 0; i < h*w*3; i++)
+    {
+        clip_img.buf[i] = img[i];
+    }
+    clip_img.nx = w;
+    clip_img.ny = h;
+    // ctx->vision_model.hparams.image_size = h;
+    clip_image_encode(ctx, n_threads, &clip_img, vec);
+    return true;
+}
\ No newline at end of file
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index 78588bdf1..6d992d901 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -55,8 +55,9 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
 CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
-CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
-CLIP_API int
clip_n_mmproj_embd(const struct clip_ctx * ctx); +CLIP_API int clip_n_patches (const struct clip_ctx * ctx); +CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); +CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx); CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); @@ -87,6 +88,7 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); #ifdef __cplusplus } #endif diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 2d5b32fe6..7b8c4fd2d 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -1,159 +1,141 @@ import argparse import glob import os +from typing import Any, Dict + import torch -from safetensors import safe_open -from safetensors.torch import save_file -from typing import Any, ContextManager, cast - -# Function to determine if file is a SafeTensor file -def is_safetensor_file(file_path): - return file_path.endswith('.safetensors') +from gguf import * +from transformers import ( + Qwen2VLForConditionalGeneration, + Qwen2VLProcessor, + AutoProcessor, + Qwen2VLConfig +) -# Unified loading function -def load_model(file_path): - if is_safetensor_file(file_path): - tensors = {} - with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f: - for key in f.keys(): - tensors[key] = f.get_tensor(key).clone() - # output shape - print(f"{key} : {tensors[key].shape}") - return tensors, 'safetensor' - else: - return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch' +VISION = "clip.vision" -# Unified saving function -def save_model(model, file_path, file_type): - if file_type == 'safetensor': - # safe_save(model, file_path) - save_file(model, file_path) - else: - torch.save(model, file_path) +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) -# Adapted function to clean vision tower from checkpoint -def clean_vision_tower_from_checkpoint(checkpoint_path): - checkpoint, file_type = load_model(checkpoint_path) - # file_type = 'pytorch' - model_path = os.path.dirname(checkpoint_path) - print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))] +def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") + # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[to_gguf_name] {og} --> {name}") + return name - if len(clip_tensors) > 0: - print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") - # Adapted for file type - clip_path = os.path.join(model_path, "llava.clip") - if os.path.exists(clip_path): - print(f"Loading existing llava.clip from {clip_path}") - existing_clip, _ = load_model(clip_path) +def 
find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: + vision_model = qwen2vl.visual + tensor_map = {} + for name, ten in vision_model.state_dict().items(): + ten = ten.numpy().astype(dtype) + if 'qkv' in name: + if ten.ndim == 2: # weight + c3, _ = ten.shape + else: # bias + c3 = ten.shape[0] + c = c3//3 + wq = ten[:c] + wk = ten[c: c * 2] + wv = ten[c * 2:] + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv + # breakpoint() + elif 'merger' in name: + if name.endswith("ln_q.weight"): + tensor_map['v.post_ln.weight'] = ten + elif name.endswith("ln_q.bias"): + tensor_map['v.post_ln.bias'] = ten + else: + # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" + tensor_map[to_gguf_name(name)] = ten + elif 'patch_embed.proj.weight' in name: + # NOTE: split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = ten.shape + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] + tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] else: - print(f"Creating new llava.clip at {clip_path}") - existing_clip = {} - # Update existing_clip with new tensors, avoid duplicates - for name in clip_tensors: - simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name - print(f"Adding {simple_name} to llava.clip") - if simple_name not in existing_clip: - existing_clip[simple_name] = checkpoint[name] - - # Save the updated clip tensors back to llava.clip - save_model(existing_clip, clip_path, 'pytorch') - - # Remove the tensors from the original checkpoint - for name in clip_tensors: - del checkpoint[name] - - checkpoint_path = checkpoint_path - return True - return False - -def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): - newline_checkpoint_path = None - projector_checkpoint_path = None - - for path in checkpoint_paths: - checkpoint, _ = load_model(path) - if newline_criteria(checkpoint) and newline_checkpoint_path is None: - newline_checkpoint_path = path - if projector(checkpoint): - projector_checkpoint_path = path - - return newline_checkpoint_path, projector_checkpoint_path - -def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") for k in checkpoint.keys()) - -def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) + tensor_map[to_gguf_name(f"vision_model.{name}")] = ten + + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=dtype) # dummy tensor, just here as a placeholder + return tensor_map -# Command-line interface setup -ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model") -ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files") -args = ap.parse_args() +def main(data_type='fp32'): + if data_type == 'fp32': + dtype = torch.float32 + np_dtype = np.float32 + ftype = 0 + elif data_type == 'fp16': + dtype = torch.float16 + np_dtype = np.float16 + ftype = 1 + else: + raise ValueError() + + model_name = "Qwen/Qwen2-VL-2B-Instruct" + qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2VLConfig = qwen2vl.config + vcfg = cfg.vision_config + rope_cfg = cfg.rope_scaling -if 
args.clean_vision_tower: - # Generalized to handle both PyTorch and SafeTensors models - model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) - # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))] - checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] - for projector_checkpoint_path in checkpoint_paths: - print(f"Cleaning {projector_checkpoint_path}") - if not clean_vision_tower_from_checkpoint(projector_checkpoint_path): - print(f"No vision tower found in {projector_checkpoint_path}") - # we break once none is found, so far all models append them at the end - # break - print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.") -# Now we look for the projector in the last checkpoint -model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) -checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] -# last_checkpoint_path = checkpoint_paths[0] -# first_checkpoint_path = checkpoint_paths[-1] -newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria) + fname_out = "qwen2vl-vision.gguf" + fout = GGUFWriter(path=fname_out, arch="clip") + fout.add_description("image encoder for Qwen2VL") -print(f"Taking projector from {projector_checkpoint_path}") -first_mm_tensors = [] -first_checkpoint = None -if newline_checkpoint_path is not None: - print(f"Taking newline from {newline_checkpoint_path}") - first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] + fout.add_file_type(ftype) + fout.add_bool("clip.has_text_encoder", False) + fout.add_bool("clip.has_vision_encoder", True) + fout.add_bool("clip.has_qwen2vl_merger", True) + fout.add_string("clip.projector_type", "qwen2vl_merger") -# Load the checkpoint -mm_tensors = [] -last_checkpoint = None -if projector_checkpoint_path is not None: - last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] + if 'silu' in cfg.vision_config.hidden_act.lower(): + fout.add_bool("clip.use_silu", True) + fout.add_bool("clip.use_gelu", False) + elif 'gelu' in cfg.vision_config.hidden_act.lower(): + fout.add_bool("clip.use_silu", False) + fout.add_bool("clip.use_gelu", True) + else: + raise ValueError() -if len(mm_tensors) == 0: - if last_checkpoint is not None: - for k, v in last_checkpoint.items(): - print(k) - print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.") - print("No tensors found. 
Is this a LLaVA model?") - exit() + tensor_map = find_vision_tensors(qwen2vl, np_dtype) + for name, data in tensor_map.items(): + fout.add_tensor(name, data) -print(f"Found {len(mm_tensors)} tensors to extract.") -print(f"Found additional {len(first_mm_tensors)} tensors to extract.") -# projector = {name: checkpoint.[name].float() for name in mm_tensors} -projector = {} -for name in mm_tensors: - assert last_checkpoint is not None - projector[name] = last_checkpoint[name].float() -for name in first_mm_tensors: - assert first_checkpoint is not None - projector[name] = first_checkpoint[name].float() + fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) + fout.add_uint32("clip.vision.image_size", 14*40) # some reasonable size that is divable by (14*2) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) + fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # BUG: not sure what this does + fout.add_name(model_name) + # fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"]) -if len(projector) > 0: - save_model(projector, f"{args.model}/llava.projector", 'pytorch') + processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) + # breakpoint() + fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) + fout.add_array("clip.vision.image_std", processor.image_processor.image_std) -print("Done!") -print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") -print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") + fout.write_header_to_file() + fout.write_kv_data_to_file() + fout.write_tensors_to_file() + fout.close() + + +main() \ No newline at end of file diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 68805b236..f070f80ff 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -650,6 +650,32 @@ static void tmp_test_mrope_2d(struct llava_context * ctx_llava, gpt_params * par } } + +static void tmp_dump_img_embed(struct llava_context * ctx_llava, gpt_params * params) { + // auto * image_embed = load_image(ctx_llava, params, "/home/ron/Downloads/gguf/dog.jpeg"); + int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); + // int ne = n_embd * image_embed->n_image_pos; + int ne = n_embd * 4; + float vals[56 * 56 * 3]; + float embd[ne]; + for (int i = 0; i < 56*56*3; i++) + { + vals[i] = (float)(i % (56 * 56)) / (56*56); + } + // auto param = &ctx_llava->ctx_clip->vision_model.hparams; + tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd); + + std::ofstream outFile("img_embed.bin", std::ios::binary); + if (outFile.is_open()) { + outFile.write(reinterpret_cast(embd), ne * sizeof(float)); + + outFile.close(); + std::cout << "Data successfully written to mrope.bin" << std::endl; + } else { + std::cerr << "Error opening file!" 
<< std::endl; + } +} + /* ----------------------------------------------------------------------------------------------------------------- */ @@ -693,7 +719,8 @@ int main(int argc, char ** argv) { auto ctx_llava = llava_init_context(¶ms, model); // process the prompt - tmp_test_4d_reshape(ctx_llava, ¶ms); + tmp_dump_img_embed(ctx_llava, ¶ms); + // tmp_test_4d_reshape(ctx_llava, ¶ms); // tmp_test_rope(ctx_llava, ¶ms); // tmp_test_mrope(ctx_llava, ¶ms); // tmp_test_mrope_2d(ctx_llava, ¶ms); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7df23371c..3feabfd64 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -226,6 +226,7 @@ class MODEL_ARCH(IntEnum): QWEN = auto() QWEN2 = auto() QWEN2MOE = auto() + QWEN2VL = auto() PHI2 = auto() PHI3 = auto() PLAMO = auto() @@ -388,6 +389,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PLAMO: "plamo", @@ -771,6 +773,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.QWEN2VL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.QWEN2MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM,
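
Note (not part of the patch): qwen2_vl_surgery.py splits visual.patch_embed.proj, a Conv3D with temporal_patch_size of 2, into the two Conv2D kernels "v.patch_embd.weight" and "v.patch_embd.weight.1", and clip.cpp then adds the outputs of the two convolutions. The sketch below checks that equivalence; it assumes a temporal kernel size of 2 and a bias-free projection, and the layer sizes are illustrative rather than read from the real config.

import torch

c_in, c_out, t, p = 3, 32, 2, 14                  # illustrative channel / patch sizes
conv3d = torch.nn.Conv3d(c_in, c_out, kernel_size=(t, p, p), stride=(t, p, p), bias=False)

# split the (c_out, c_in, 2, p, p) kernel along the temporal axis, as find_vision_tensors() does
w = conv3d.weight.data
conv2d_0 = torch.nn.Conv2d(c_in, c_out, kernel_size=p, stride=p, bias=False)
conv2d_1 = torch.nn.Conv2d(c_in, c_out, kernel_size=p, stride=p, bias=False)
conv2d_0.weight.data = w[:, :, 0]                 # -> "v.patch_embd.weight"
conv2d_1.weight.data = w[:, :, 1]                 # -> "v.patch_embd.weight.1"

img = torch.randn(1, c_in, 56, 56)                # a single still image
vid = img.unsqueeze(2).repeat(1, 1, t, 1, 1)      # Qwen2VL repeats it along the temporal axis

out_3d = conv3d(vid).flatten(2)
out_2d = (conv2d_0(img) + conv2d_1(img)).flatten(2)
print(torch.allclose(out_3d, out_2d, atol=1e-4))  # True: summing the two Conv2D outputs matches the Conv3D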
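A second sketch (also not part of the patch, grid size illustrative): how the qwen2vl_merger groups the patch grid into 2x2 blocks, which is why the projector input width is hidden_size * 4 and why the position-id loop in clip_image_batch_encode walks the grid in steps of two.

import numpy as np

ph, pw, c = 4, 6, 8                                # patch grid and per-patch width, illustrative
grid = np.arange(ph * pw * c).reshape(ph, pw, c)

order, pos_y, pos_x = [], [], []
for y in range(0, ph, 2):                          # same traversal as the positions loop
    for x in range(0, pw, 2):
        for dy in range(2):
            for dx in range(2):
                order.append(grid[y + dy, x + dx])
                pos_y.append(y + dy)
                pos_x.append(x + dx)

tokens = np.stack(order)                           # (ph*pw, c): patches in merge order
merged = tokens.reshape(ph * pw // 4, 4 * c)       # one row per merged token: a full 2x2 block
pos_zero = [0] * (ph * pw)                         # third position plane stays zero (its mrope section width is 0)
print(merged.shape)                                # (6, 32)
print(pos_y[:4], pos_x[:4])                        # [0, 0, 1, 1] [0, 1, 0, 1]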