From 6ccf234031a259bd797137b0ce27efbcd2fc1a34 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 16 Jan 2025 14:54:31 -0700 Subject: [PATCH 01/21] Add super wip scripts for multimodal granite gguf Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 83 +++++++++++++++++-- .../llava/convert_image_encoder_to_gguf.py | 26 ++++-- examples/llava/llava_surgery_v2.py | 26 ++++-- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- 4 files changed, 119 insertions(+), 20 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 271cf2a2a..136930946 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -120,7 +120,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" - +#define KEY_VISION_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" @@ -444,8 +444,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; + int32_t image_grid_pinpoints[32]; // TODO - check to make sure this is okay for our model... int32_t image_crop_resolution; + int32_t vision_feature_layer[4]; }; struct clip_layer { @@ -615,6 +616,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; } + LOG_INF("In the graph builder...\n"); const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -666,9 +668,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 /*.mem_buffer =*/ ctx->buf_compute_meta.data(), /*.no_alloc =*/ true, }; + LOG_INF("Making the graph...\n"); struct ggml_context * ctx0 = ggml_init(params); struct ggml_cgraph * gf = ggml_new_graph(ctx0); + LOG_INF("Graph made...\n"); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); ggml_set_name(inp_raw, "inp_raw"); @@ -751,13 +755,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } + LOG_INF("About to iterate over layers...\n"); // loop over layers if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { n_layer += 1; } + + // HACK - hold 4 vectors to stack + std::vector embeddingStack; + for (int il = 0; il < n_layer - 1; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + LOG_INF("\tLayer %d...\n", il); + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; @@ -846,7 +857,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - + // Stack embedding feature layers + // HACK - these values might be decremented unncessarily, check hparams layer; maybe this is the int feature layer index? + for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { + if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { + embeddingStack.push_back(embeddings); + LOG_INF("Saving layer %d...\n", il); + break; + } + } } // post-layernorm @@ -856,6 +875,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } + LOG_INF("Layer loop over - trying to llava project...\n"); + // HACK - super hardcoded tensor concat to make sure things are working. Rewrite me + struct ggml_tensor * embeddingStack1 = ggml_concat(ctx0, embeddingStack.at(0), embeddingStack.at(1), 0); + struct ggml_tensor * embeddingStack2 = ggml_concat(ctx0, embeddingStack.at(2), embeddingStack.at(3), 0); + embeddings = ggml_concat(ctx0, embeddingStack1, embeddingStack2, 0); // llava projector if (ctx->has_llava_projector) { @@ -873,7 +897,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // llava projector if (ctx->proj_type == PROJECTOR_TYPE_MLP) { + LOG_INF("proj mlp: mm 0 shape: [%d, %d, %d, %d] | embedding shape: [%d, %d, %d, %d]\n", model.mm_0_w->ne[0], model.mm_0_w->ne[1], model.mm_0_w->ne[2], model.mm_0_w->ne[3], embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + LOG_INF("proj mlp - first mulmat done\n"); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_gelu(ctx0, embeddings); @@ -881,6 +907,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + LOG_INF("proj mlp norm\n"); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1152,11 +1179,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + LOG_INF("forward expanding\n"); // build the graph ggml_build_forward_expand(gf, embeddings); + LOG_INF("forward expand done\n"); ggml_free(ctx0); + LOG_INF("freeing it all\n"); return gf; } @@ -1424,7 +1454,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } fin.close(); } - + LOG_INF("%s: We are up to the vision model\n", __func__); // vision model if (new_clip->has_vision_encoder) { // load vision model @@ -1452,6 +1482,33 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { hparams.image_grid_pinpoints[0]=0; } + // Load the vision feature layer indices; For most models, this will be + // an array of length one with value -1 (i.e., use last layer as visual features), + // but for IBM granite, we have multiple feature layers that get concatenated. + // + // Here, we should standardize all values to uint values so that we can use -1 as unset values. + // try { + // int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); + // int n = gguf_get_arr_n(ctx, idx); + // const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + // // HACK - need to set a good invalid number here; or maybe not, I guess it could just + // // be that it's not set in GGUF, we read all numbers as valid, and from this point on, + // // -1 is the sad one + // for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { + // hparams.vision_feature_layer[i] = vision_feature_layer[i]; + // } + // if (n < 4) + // hparams.image_grid_pinpoints[n] = -1; + // } catch (std::runtime_error & /*e*/) { + // // -1 -> taking the final layer output + // hparams.vision_feature_layer[0] = -1; + // } + // HACK for testing without GGUF hparams for now + hparams.vision_feature_layer[0] = 3; + hparams.vision_feature_layer[1] = 7; + hparams.vision_feature_layer[2] = 15; + hparams.vision_feature_layer[3] = 24; // TODO This is wrong and should be 26, but the converter seems to be chopping layers off; investigate + try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); @@ -1493,6 +1550,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } LOG_INF("\n"); + LOG_INF("vision_feature_layer: "); + for(int i = 0; i < 4 && (hparams.vision_feature_layer[i] > 0); i++) { + LOG_INF("%d ", hparams.vision_feature_layer[i]); + } + LOG_INF("\n"); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); } @@ -1504,6 +1566,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_class_embedding = false; } + LOG_INF("Has class embedding: %d", new_clip->has_class_embedding); + try { vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); @@ -1538,6 +1602,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } catch(const std::exception& /*e*/) { new_clip->has_qwen2vl_merger = false; } + LOG_INF("Loaded up to llava projection"); // LLaVA projection if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { @@ -1675,6 +1740,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->ctx_gguf = ctx; + LOG_INF("About to measure memory and build graphs...\n"); // measure mem requirement and allocate { new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); @@ -1682,6 +1748,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { clip_image_f32_batch batch; batch.size = 1; batch.data = nullptr; + LOG_INF("Entering graph...\n"); ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); @@ -2560,8 +2627,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph + LOG_INF("Doing a batch encode\n"); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + LOG_INF("did graph alloc\n"); // set inputs const auto & model = ctx->vision_model; @@ -2721,18 +2790,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } } + LOG_INF("about to do backend graph compute\n"); if (ggml_backend_is_cpu(ctx->backend)) { ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); } - + LOG_INF("-----\n"); ggml_backend_graph_compute(ctx->backend, gf); + LOG_INF("did backend graph compute\n"); // the last node is the embedding tensor struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); + LOG_INF("retrieved emb tensor\n"); // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + LOG_INF("embeddings have been recopied\n"); if (ctx->has_glm_projector) { //eoi diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 4fa1d6cea..954b5b442 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -6,7 +6,7 @@ import re import torch import numpy as np from gguf import * -from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel +from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipModel, SiglipProcessor, SiglipVisionModel TEXT = "clip.text" VISION = "clip.vision" @@ -85,6 +85,8 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False, help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--clip-model-is-siglip", action="store_true", required=False, + help="the visual encoder is Siglip.") ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) @@ -109,7 +111,7 @@ if args.use_f32: # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: +if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip or args.clip_model_is_siglip: vocab = None tokens = None else: @@ -137,7 +139,11 @@ ftype = 1 if args.use_f32: ftype = 0 -if args.clip_model_is_vision or args.clip_model_is_openclip: +# HACK - not sure if we need the vision model of the model + processor; check the difference +if args.clip_model_is_vision or args.clip_model_is_siglip: + model = SiglipVisionModel.from_pretrained(dir_model) + processor = None +elif args.clip_model_is_vision or args.clip_model_is_openclip: model = CLIPVisionModel.from_pretrained(dir_model) processor = None else: @@ -187,26 +193,34 @@ else: if has_text_encoder: assert t_hparams is not None assert tokens is not None + if args.clip_model_is_siglip: + text_projection_dim = 0 + else: + text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"]) # text_model hparams fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) - fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.text.projection_dim", text_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) fout.add_token_list(tokens) if has_vision_encoder: + if args.clip_model_is_siglip: + visual_projection_dim = 0 + else: + visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) # vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) - fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] # Why is this decremented? Should be 27... fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py index 2d5b32fe6..5119c9ccc 100644 --- a/examples/llava/llava_surgery_v2.py +++ b/examples/llava/llava_surgery_v2.py @@ -40,7 +40,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): # file_type = 'pytorch' model_path = os.path.dirname(checkpoint_path) print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))] + clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))] if len(clip_tensors) > 0: print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") @@ -85,10 +85,10 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): return newline_checkpoint_path, projector_checkpoint_path def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") for k in checkpoint.keys()) + return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys()) def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) + return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys()) # Command-line interface setup @@ -123,14 +123,14 @@ first_checkpoint = None if newline_checkpoint_path is not None: print(f"Taking newline from {newline_checkpoint_path}") first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] + first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")] # Load the checkpoint mm_tensors = [] last_checkpoint = None if projector_checkpoint_path is not None: last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] + mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")] if len(mm_tensors) == 0: if last_checkpoint is not None: @@ -146,14 +146,24 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.") projector = {} for name in mm_tensors: assert last_checkpoint is not None - projector[name] = last_checkpoint[name].float() + # HACK - this should probably be in the second script... + new_name = name + if new_name.startswith("multi_modal_projector.linear_1"): + new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0") + elif new_name.startswith("multi_modal_projector.linear_2"): + new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2") + projector[new_name] = last_checkpoint[name].float() for name in first_mm_tensors: assert first_checkpoint is not None - projector[name] = first_checkpoint[name].float() + # HACK - this should probably be in the second script too... + new_name = name + if new_name == "image_newline": + new_name = "model.image_newline" + projector[new_name] = first_checkpoint[name].float() if len(projector) > 0: save_model(projector, f"{args.model}/llava.projector", 'pytorch') print("Done!") -print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index e809f05d2..a86bdb939 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -8515,7 +8515,9 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - GGML_ASSERT(i01 >= 0 && i01 < ne01); + // Copying this out for a bit while investigating due to issues like: + // https://github.com/ggerganov/llama.cpp/issues/10157 + // GGML_ASSERT(i01 >= 0 && i01 < ne01); ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), From fd0111c0436d23c9f4db5f4445ba790f56922928 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 16 Jan 2025 15:42:14 -0700 Subject: [PATCH 02/21] Add example for converting mmgranite to gguf Signed-off-by: Alex-Brooks --- examples/llava/README-mmgranite.md | 169 +++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 examples/llava/README-mmgranite.md diff --git a/examples/llava/README-mmgranite.md b/examples/llava/README-mmgranite.md new file mode 100644 index 000000000..e3328935f --- /dev/null +++ b/examples/llava/README-mmgranite.md @@ -0,0 +1,169 @@ +# Instructions to Convert Multimodal Granite -> GGUF +`export GRANITE_MODEL=/Users/alexanderjbrooks/Desktop/llava-granite-2b-vllm` + +Disclaimer that this branch is super WIP; eventually this should be combined with the main README in this directory, but separating it for now since we use a different method for converting the LLM. + + + +### 1. Running llava surgery v2. +First, we need to run the llava surgery script as shown below: + +`python llava_surgery_v2.py -C -m $GRANITE_MODEL` + +You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory. You can load them directly with pytorch and validate that they are nonempty using the snippet below. + +`ls $GRANITE_MODEL | grep -i llava` + + + +We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty: +```python +import os +import torch + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip")) +projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector")) + +assert len(encoder_tensors) > 0 +assert len(projector_tensors) > 0 +``` + +If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'mm.0.bias'`, `'mm.0.weight'`, `'mm.2.bias'`, `'mm.2.weight'`, `'model.image_newline'`) in the multimodal `projector_tensors`. + + + +### 2. Creating the Visual Component GGUF +To create the GGUF for the visual components, we need to write a config for the visual encoder. Here is an example Alex wrote for initial testing using the values in the preliminary model; if things are going wrong, there is a good chance it's a misalignment with the config here. + +Note: we refer to this file as `$VISION_CONFIG` later on. +```json +{ + "_name_or_path": "siglip-model", + "architectures": [ + "SiglipVisionModel" + ], + "image_grid_pinpoints": [ + [384,768], + [384,1152], + [384,1536], + [384,1920], + [384,2304], + [384,2688], + [384,3072], + [384,3456], + [384,3840], + [768,384], + [768,768], + [768,1152], + [768,1536], + [768,1920], + [1152,384], + [1152,768], + [1152,1152], + [1536,384], + [1536,768], + [1920,384], + [1920,768], + [2304,384], + [2688,384], + [3072,384], + [3456,384], + [3840,384] + ], + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "transformers_version": "4.45.0.dev0", + "layer_norm_eps": 1e-6, + "hidden_act": "gelu_pytorch_tanh", + "projection_dim": 0 +} +``` + +Create a new directory to hope the visual components, and copy the llava.clip/projector files, as well as the vision config into it. + +``` +ENCODER_PATH=... +VISION_CONFIG=... +mkdir $ENCODER_PATH + +cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin +cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/ +cp $VISION_CONFIG $ENCODER_PATH/config.json +``` + +At which point you should have something like this: +```bash +(venv) alexanderjbrooks@wecm-9-67-137-179 llava % ls $ENCODER_PATH +config.json llava.projector pytorch_model.bin +``` + +Now convert the components to GGUF. +```bash +python convert_image_encoder_to_gguf.py \ + -m $ENCODER_PATH \ + --llava-projector $ENCODER_PATH/llava.projector \ + --output-dir mgranite_siglip \ + --clip-model-is-vision \ + --clip-model-is-siglip +``` + +which will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the abs path of this file as the `$VISUAL_GGUF_PATH.` + + + +### 3. Creating the LLM GGUF. +For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted (Alex will add support to the converter for llava next if possible, but hacking to ignore unused tensors etc with the current instructions currently results in the tokenizer embedding weights not being found). + +To do this, you can do something like the following; we assume you're setting the environment variable `LLM_EXPORT_PATH` to the place to put the exported `transformers` LLM. + +```python +import os +import transformers + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH") +if not MODEL_PATH: + raise ValueError("env var LLM_EXPORT_PATH is unset!") + +tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH) +# NOTE: need to ignore mismatched sizes for now until Alex's mmgranite PR is merged in transformers +# This also causes actual problems for the load multimodal projector, but since we only +# export the LLM, we don't care for now... +model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True) + +tokenizer.save_pretrained(LLM_EXPORT_PATH) +model.language_model.save_pretrained(LLM_EXPORT_PATH) +``` + +Now you can convert the exported LLM to GGUF with the normal converter. + +```bash +LLM_GGUF_PATH=... + +python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH +``` + + + +### 4. Running the model in llama cpp +Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. Sample usage: +``` +./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \ + --mmproj $VISUAL_GGUF_PATH \ + --image cherry_blossom.jpg \ + -c 16384 \ + -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\\nCan you describe this image?\n<|assistant|>\n" +``` + From bc66d1931bf79435bf7c9f7e710bc4673bd9c5c5 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 16 Jan 2025 15:49:05 -0700 Subject: [PATCH 03/21] remove hardcoded path Signed-off-by: Alex-Brooks --- examples/llava/README-mmgranite.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/llava/README-mmgranite.md b/examples/llava/README-mmgranite.md index e3328935f..b5d59bbff 100644 --- a/examples/llava/README-mmgranite.md +++ b/examples/llava/README-mmgranite.md @@ -1,8 +1,9 @@ # Instructions to Convert Multimodal Granite -> GGUF -`export GRANITE_MODEL=/Users/alexanderjbrooks/Desktop/llava-granite-2b-vllm` - Disclaimer that this branch is super WIP; eventually this should be combined with the main README in this directory, but separating it for now since we use a different method for converting the LLM. +First, set the env var `$GRANITE_MODEL` to your vLLM/transformers format multimodal granite model. + +`export GRANITE_MODEL=...` ### 1. Running llava surgery v2. From 92046a103da1026664a95c9ff38629ba6ce23452 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 22 Jan 2025 01:27:33 -0700 Subject: [PATCH 04/21] Add vision feature layer to gguf params Signed-off-by: Alex-Brooks --- examples/llava/convert_image_encoder_to_gguf.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 954b5b442..cb0495156 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -6,7 +6,7 @@ import re import torch import numpy as np from gguf import * -from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipModel, SiglipProcessor, SiglipVisionModel +from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel TEXT = "clip.text" VISION = "clip.vision" @@ -208,6 +208,15 @@ if has_text_encoder: fout.add_token_list(tokens) if has_vision_encoder: + # vision feature layer may be an integer or an array. + # TODO - it seems like llama cpp may not handle this correctly + # normally; check if HF llava next models can run through this converter... + if "vision_feature_layer" in v_hparams: + feature_layers = v_hparams["vision_feature_layer"] + if isinstance(feature_layers, int): + feature_layers = [feature_layers] + fout.add_array("clip.vision.feature_layer", feature_layers) + if args.clip_model_is_siglip: visual_projection_dim = 0 else: From cc1c135367937261e3b4823476c596df37e25d40 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 22 Jan 2025 01:43:49 -0700 Subject: [PATCH 05/21] Clean up llava surgery and remove name substitution hacks Signed-off-by: Alex-Brooks --- examples/llava/llava_surgery_v2.py | 53 ++++++++++++++++++------------ 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py index 5119c9ccc..b07c3e323 100644 --- a/examples/llava/llava_surgery_v2.py +++ b/examples/llava/llava_surgery_v2.py @@ -33,6 +33,33 @@ def save_model(model, file_path, file_type): else: torch.save(model, file_path) +# Helpers to match weight names from specific components or +# determine if a saved shard contains that component +def is_vision_tower(weight_name): + return ( + weight_name.startswith("model.vision_tower") or + weight_name.startswith("vit.") or + weight_name.startswith("vision_tower") + ) + +def is_newline(weight_name): + return ( + weight_name.startswith("model.image_newline") or + weight_name.startswith("image_newline") + ) + +def is_mm_projector(weight_name): + return ( + weight_name.startswith("model.mm_projector") or + weight_name.startswith("vision_proj.") or + weight_name.startswith("multi_modal_projector") + ) + +def newline_criteria(checkpoint): + return any(is_newline(k) for k in checkpoint.keys()) + +def proj_criteria(checkpoint): + return any(is_mm_projector(k) for k in checkpoint.keys()) # Adapted function to clean vision tower from checkpoint def clean_vision_tower_from_checkpoint(checkpoint_path): @@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): # file_type = 'pytorch' model_path = os.path.dirname(checkpoint_path) print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))] + clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)] if len(clip_tensors) > 0: print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") @@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): return newline_checkpoint_path, projector_checkpoint_path -def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys()) - -def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys()) - # Command-line interface setup ap = argparse.ArgumentParser() @@ -123,14 +144,14 @@ first_checkpoint = None if newline_checkpoint_path is not None: print(f"Taking newline from {newline_checkpoint_path}") first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")] + first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)] # Load the checkpoint mm_tensors = [] last_checkpoint = None if projector_checkpoint_path is not None: last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")] + mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)] if len(mm_tensors) == 0: if last_checkpoint is not None: @@ -146,20 +167,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.") projector = {} for name in mm_tensors: assert last_checkpoint is not None - # HACK - this should probably be in the second script... - new_name = name - if new_name.startswith("multi_modal_projector.linear_1"): - new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0") - elif new_name.startswith("multi_modal_projector.linear_2"): - new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2") - projector[new_name] = last_checkpoint[name].float() + projector[name] = last_checkpoint[name].float() for name in first_mm_tensors: assert first_checkpoint is not None - # HACK - this should probably be in the second script too... - new_name = name - if new_name == "image_newline": - new_name = "model.image_newline" - projector[new_name] = first_checkpoint[name].float() + projector[name] = first_checkpoint[name].float() if len(projector) > 0: save_model(projector, f"{args.model}/llava.projector", 'pytorch') From 50504063b2c74d29e1e5d20ccadf3e9550c39537 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 22 Jan 2025 01:44:48 -0700 Subject: [PATCH 06/21] Add transformers llava next tensor name mapping Signed-off-by: Alex-Brooks --- examples/llava/convert_image_encoder_to_gguf.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index cb0495156..d51428c3f 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b def get_tensor_name(name: str) -> str: + # Standardize the transformers llava next keys for + # image newline / mm projector with the classes in haotian-liu LLaVA + if name == "image_newline": + return "model.image_newline" + if name.startswith("multi_modal_projector"): + name = name.replace("multi_modal_projector", "mm") + if name.endswith("linear_1"): + name = name.replace("linear_1", "0") + if name.endswith("linear_2"): + name = name.replace("linear_2", "1") + return name + if "projection" in name: return name if "mm_projector" in name: From 61d4ae469909e91dfb86c8cbd49cd04347977df8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 22 Jan 2025 02:27:00 -0700 Subject: [PATCH 07/21] Make siglip / openclip mutuall exclusive Signed-off-by: Alex-Brooks --- .../llava/convert_image_encoder_to_gguf.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index d51428c3f..91377c8cd 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -95,10 +95,14 @@ ap.add_argument("--vision-only", action="store_true", required=False, help="Save a vision-only model. It can't be used to encode texts") ap.add_argument("--clip-model-is-vision", action="store_true", required=False, help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") -ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + +# Selectable visual encoders that are compatible with this script +encoder_group = ap.add_mutually_exclusive_group() +encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False, help="The clip model is from openclip (for ViT-SO400M type))") -ap.add_argument("--clip-model-is-siglip", action="store_true", required=False, +encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False, help="the visual encoder is Siglip.") + ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) @@ -123,7 +127,12 @@ if args.use_f32: # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip or args.clip_model_is_siglip: +if ( + args.clip_model_is_vision or + not os.path.exists(dir_model + "/vocab.json") or + args.clip_model_is_openclip or + args.clip_model_is_siglip +): vocab = None tokens = None else: @@ -151,10 +160,9 @@ ftype = 1 if args.use_f32: ftype = 0 -# HACK - not sure if we need the vision model of the model + processor; check the difference -if args.clip_model_is_vision or args.clip_model_is_siglip: +if args.clip_model_is_siglip: model = SiglipVisionModel.from_pretrained(dir_model) - processor = None + processor = None # TODO - optionally handle processor to correctly extract image stats etc elif args.clip_model_is_vision or args.clip_model_is_openclip: model = CLIPVisionModel.from_pretrained(dir_model) processor = None @@ -229,10 +237,12 @@ if has_vision_encoder: feature_layers = [feature_layers] fout.add_array("clip.vision.feature_layer", feature_layers) + # Siglip does not have a visual projector; set projection dim to 0 if args.clip_model_is_siglip: visual_projection_dim = 0 else: visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) + # vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) From 7905f9dd403cbee4700280a1fc436ff72067a772 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 26 Jan 2025 01:40:07 -0700 Subject: [PATCH 08/21] Fix projector linear substitution Signed-off-by: Alex-Brooks --- examples/llava/convert_image_encoder_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 91377c8cd..030899118 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -43,9 +43,9 @@ def get_tensor_name(name: str) -> str: return "model.image_newline" if name.startswith("multi_modal_projector"): name = name.replace("multi_modal_projector", "mm") - if name.endswith("linear_1"): + if "linear_1" in name: name = name.replace("linear_1", "0") - if name.endswith("linear_2"): + if "linear_2" in name: name = name.replace("linear_2", "1") return name @@ -251,7 +251,7 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] # Why is this decremented? Should be 27... + block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ From 987f76840a4c3f2333356e1e96b30b75ed13d3fb Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 27 Jan 2025 12:09:04 -0700 Subject: [PATCH 09/21] Fix linear 2 substitution index Signed-off-by: Alex-Brooks --- examples/llava/convert_image_encoder_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 030899118..730ee3e67 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -46,7 +46,7 @@ def get_tensor_name(name: str) -> str: if "linear_1" in name: name = name.replace("linear_1", "0") if "linear_2" in name: - name = name.replace("linear_2", "1") + name = name.replace("linear_2", "2") return name if "projection" in name: From e1ec8511213b84a43409574eedd65f07b8d777b6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 27 Jan 2025 14:45:32 -0700 Subject: [PATCH 10/21] Increase max flattened gridpoints to 64 Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 136930946..3d8811282 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -170,6 +170,9 @@ static std::string format(const char * fmt, ...) { #define TN_GLM_BOI_W "adapter.boi" #define TN_GLM_EOI_W "adapter.eoi" +// Other constants for max array lengths of hparams etc +#define MAX_IMAGE_GRID_PINPOINTS 64 +#define MAX_VISION_FEATURE_LAYERS 4 enum projector_type { PROJECTOR_TYPE_MLP, @@ -430,7 +433,6 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u // // clip layers // - struct clip_hparams { int32_t image_size; int32_t patch_size; @@ -444,9 +446,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; // TODO - check to make sure this is okay for our model... + int32_t image_grid_pinpoints[MAX_IMAGE_GRID_PINPOINTS]; int32_t image_crop_resolution; - int32_t vision_feature_layer[4]; + int32_t vision_feature_layer[MAX_VISION_FEATURE_LAYERS]; }; struct clip_layer { @@ -1473,10 +1475,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) { + LOG_INF("Grid pinpoints | max %d | actual %d ", MAX_IMAGE_GRID_PINPOINTS, n); + for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && i < n && pinpoints[i] != 0; ++i) { + LOG_INF(" %d ", i); hparams.image_grid_pinpoints[i] = pinpoints[i]; } - if (n < 32) + LOG_INF("\n"); + if (n < MAX_IMAGE_GRID_PINPOINTS) hparams.image_grid_pinpoints[n] = 0; } catch (std::runtime_error & /*e*/) { hparams.image_grid_pinpoints[0]=0; @@ -1546,7 +1551,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_grid_pinpoints: "); - for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) { + for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && (hparams.image_grid_pinpoints[i] != 0); ++i) { LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } LOG_INF("\n"); @@ -2305,7 +2310,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli if (params.image_grid_pinpoints[0] != 0) { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { + for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && params.image_grid_pinpoints[i] != 0; i+=2) { possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); } std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); From ae291e5405e0a4b6facbdc3d65d45d6babd2d323 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 27 Jan 2025 15:00:24 -0700 Subject: [PATCH 11/21] Fix hardcoded concat for multiple feature layers Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 3d8811282..0546f66bd 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -871,18 +871,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // post-layernorm + // TODO - correctly handle last layer with multiple vision feature layers if (ctx->has_post_norm) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } - LOG_INF("Layer loop over - trying to llava project...\n"); - // HACK - super hardcoded tensor concat to make sure things are working. Rewrite me - struct ggml_tensor * embeddingStack1 = ggml_concat(ctx0, embeddingStack.at(0), embeddingStack.at(1), 0); - struct ggml_tensor * embeddingStack2 = ggml_concat(ctx0, embeddingStack.at(2), embeddingStack.at(3), 0); - embeddings = ggml_concat(ctx0, embeddingStack1, embeddingStack2, 0); + LOG_INF("Stacking multiple vision feature layers\n"); + // Clobber the output embeddings with the saved items in the embedding stack vector + if(embeddingStack.size() > 0) { + embeddings = embeddingStack.at(0); + for(int i=1; i < embeddingStack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embeddingStack.at(i), 0); + } + + } + + + LOG_INF("Layer loop over - trying to llava project...\n"); // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); From ab71c9e9c4c5f89f7ecc354dd6fb3d6dd352165c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 27 Jan 2025 15:23:33 -0700 Subject: [PATCH 12/21] Pull vision feature layers out of gguf keys Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 0546f66bd..ca43e8e97 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -1500,27 +1500,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // but for IBM granite, we have multiple feature layers that get concatenated. // // Here, we should standardize all values to uint values so that we can use -1 as unset values. - // try { - // int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); - // int n = gguf_get_arr_n(ctx, idx); - // const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); - // // HACK - need to set a good invalid number here; or maybe not, I guess it could just - // // be that it's not set in GGUF, we read all numbers as valid, and from this point on, - // // -1 is the sad one - // for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { - // hparams.vision_feature_layer[i] = vision_feature_layer[i]; - // } - // if (n < 4) - // hparams.image_grid_pinpoints[n] = -1; - // } catch (std::runtime_error & /*e*/) { - // // -1 -> taking the final layer output - // hparams.vision_feature_layer[0] = -1; - // } - // HACK for testing without GGUF hparams for now - hparams.vision_feature_layer[0] = 3; - hparams.vision_feature_layer[1] = 7; - hparams.vision_feature_layer[2] = 15; - hparams.vision_feature_layer[3] = 24; // TODO This is wrong and should be 26, but the converter seems to be chopping layers off; investigate + try { + LOG_INF("ABOUT TO GET VISION FEATURE LAYER KEYS\n"); + int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); + LOG_INF("VISION FEATURE LAYER IDX %d\n", idx); + int n = gguf_get_arr_n(ctx, idx); + LOG_INF("GETTING %d VISION FEATURE LAYERS \n", n); + const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + // HACK - need to set a good invalid number here; or maybe not, I guess it could just + // be that it's not set in GGUF, we read all numbers as valid, and from this point on, + // -1 is the sad one + for (int i = 0; i < MAX_VISION_FEATURE_LAYERS && i < n && vision_feature_layer[i] != 0; ++i) { + hparams.vision_feature_layer[i] = vision_feature_layer[i]; + LOG_INF("feature layer %d - %d | ", i, vision_feature_layer[i]); + } + if (n < MAX_IMAGE_GRID_PINPOINTS) + hparams.image_grid_pinpoints[n] = -1; + } catch (std::runtime_error & /*e*/) { + LOG_INF("VISION FEATURE LAYER RETRIEVAL FAILED"); + hparams.vision_feature_layer[0] = -1; + } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); From 65935431b4e521458bd9e868bed58774bc06528d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 5 Feb 2025 03:06:01 -0700 Subject: [PATCH 13/21] fix num gridpoints and use all layers Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 10 +++++++++- examples/llava/llava.cpp | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index ca43e8e97..33600b7d8 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -767,7 +767,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // HACK - hold 4 vectors to stack std::vector embeddingStack; - for (int il = 0; il < n_layer - 1; il++) { + // TODO - n_layer was previously n_layer - 1, probably to use -2 as the feature layer, + // in actuality it probably is a good idea to use that as a default, but otherwise infer + // how deep in the encoder we actually have to go if we set the hparams for the vision feature + // layer... + for (int il = 0; il < n_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states LOG_INF("\tLayer %d...\n", il); @@ -907,6 +911,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // llava projector if (ctx->proj_type == PROJECTOR_TYPE_MLP) { + LOG_INF("---- MLP projector ----"); LOG_INF("proj mlp: mm 0 shape: [%d, %d, %d, %d] | embedding shape: [%d, %d, %d, %d]\n", model.mm_0_w->ne[0], model.mm_0_w->ne[1], model.mm_0_w->ne[2], model.mm_0_w->ne[3], embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); LOG_INF("proj mlp - first mulmat done\n"); @@ -1506,6 +1511,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("VISION FEATURE LAYER IDX %d\n", idx); int n = gguf_get_arr_n(ctx, idx); LOG_INF("GETTING %d VISION FEATURE LAYERS \n", n); + // TODO - fix this + LOG_INF("n_layer in hparams is: %d\n", hparams.n_layer); + const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); // HACK - need to set a good invalid number here; or maybe not, I guess it could just // be that it's not set in GGUF, we read all numbers as valid, and from this point on, diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 300714045..35049cc2c 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -355,7 +355,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const int32_t * image_grid = clip_image_grid(ctx_clip); std::vector> grid_pinpoints; - for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { + for (int i = 0; i < 64 && image_grid[i] != 0; i += 2) { grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } From d85580c41c1ddcbf84e48e733a9e15e51d921cc6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 5 Feb 2025 03:07:35 -0700 Subject: [PATCH 14/21] Avoid dropping last image encoder layer in llava models Signed-off-by: Alex-Brooks --- examples/llava/convert_image_encoder_to_gguf.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 730ee3e67..ab70a55d6 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -162,7 +162,7 @@ if args.use_f32: if args.clip_model_is_siglip: model = SiglipVisionModel.from_pretrained(dir_model) - processor = None # TODO - optionally handle processor to correctly extract image stats etc + processor = None elif args.clip_model_is_vision or args.clip_model_is_openclip: model = CLIPVisionModel.from_pretrained(dir_model) processor = None @@ -228,10 +228,12 @@ if has_text_encoder: fout.add_token_list(tokens) if has_vision_encoder: - # vision feature layer may be an integer or an array. - # TODO - it seems like llama cpp may not handle this correctly - # normally; check if HF llava next models can run through this converter... - if "vision_feature_layer" in v_hparams: + ## FIXME Need to pull this out of the overall model config, not just the top one? + # TODO or document that vision_feature_layer can be set here, but it's usually in the + # llava config and not the vision config itself; + # Handle vision feature layers in transformers, where features may be taken + # from layers that are not the last. NOTE - these values can be unsigned... + if "vision_feature_layer" in config: feature_layers = v_hparams["vision_feature_layer"] if isinstance(feature_layers, int): feature_layers = [feature_layers] @@ -251,7 +253,7 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + block_count = v_hparams["num_hidden_layers"] #- 1 if has_llava_projector else v_hparams["num_hidden_layers"] fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ @@ -319,7 +321,7 @@ fout.add_bool("clip.use_gelu", use_gelu) if has_llava_projector: - model.vision_model.encoder.layers.pop(-1) + # model.vision_model.encoder.layers.pop(-1) projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) From 3a191f8edbc693d9318aa0843c8d34cd80e84c5d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 5 Feb 2025 03:22:40 -0700 Subject: [PATCH 15/21] Use 10 for max number of patches Signed-off-by: Alex-Brooks --- examples/llava/llava.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 35049cc2c..48795f3f8 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -405,10 +405,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * } bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - int num_max_patches = 6; - if (clip_is_minicpmv(ctx_clip)) { - num_max_patches = 10; - } + // Minicpmv / granite vision use 10 patches + int num_max_patches = 10; if (clip_is_glm(ctx_clip)) { num_max_patches = 1; } From 188a068a0497e609d0695d291fcea24927916582 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 06:43:49 -0700 Subject: [PATCH 16/21] Standardize vision feature layers Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 97 ++++++++++--------- .../llava/convert_image_encoder_to_gguf.py | 43 +++++--- 2 files changed, 85 insertions(+), 55 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 33600b7d8..44a65e40b 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -170,10 +170,6 @@ static std::string format(const char * fmt, ...) { #define TN_GLM_BOI_W "adapter.boi" #define TN_GLM_EOI_W "adapter.eoi" -// Other constants for max array lengths of hparams etc -#define MAX_IMAGE_GRID_PINPOINTS 64 -#define MAX_VISION_FEATURE_LAYERS 4 - enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP_NORM, @@ -446,9 +442,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[MAX_IMAGE_GRID_PINPOINTS]; + int32_t image_grid_pinpoints[64]; int32_t image_crop_resolution; - int32_t vision_feature_layer[MAX_VISION_FEATURE_LAYERS]; + int32_t vision_feature_layer[4]; }; struct clip_layer { @@ -759,22 +755,37 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } LOG_INF("About to iterate over layers...\n"); - // loop over layers - if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { - n_layer += 1; + // Check to see we have 1+ set vision feature layers set; otherwise it's the last layer + std::vector embedding_stack; + bool has_feature_layers = ctx->vision_model.hparams.vision_feature_layer[0] > 0; + // Determine how many encoder layers we need to process; if we have explicit vision feature + // layers, only process what we need, otherwise process all of the visual encoder layers. + int max_feature_layer = -1; + if(has_feature_layers) { + for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { + if(ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx] > max_feature_layer) { + max_feature_layer = ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]; + } + } } + if(max_feature_layer < 0) { + max_feature_layer = n_layer; + } + LOG_INF("Number of feature layers: %d\n", max_feature_layer); - // HACK - hold 4 vectors to stack - std::vector embeddingStack; - - // TODO - n_layer was previously n_layer - 1, probably to use -2 as the feature layer, - // in actuality it probably is a good idea to use that as a default, but otherwise infer - // how deep in the encoder we actually have to go if we set the hparams for the vision feature - // layer... - for (int il = 0; il < n_layer; il++) { + // loop over layers + for (int il = 0; il < max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - LOG_INF("\tLayer %d...\n", il); + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. + for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { + if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { + LOG_INF("Keeping vision feature layer: %d\n", il); + embedding_stack.push_back(embeddings); + break; + } + } //const size_t nb_q_w = model.layers[il].q_w->nb[0]; @@ -863,36 +874,34 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - // Stack embedding feature layers - // HACK - these values might be decremented unncessarily, check hparams layer; maybe this is the int feature layer index? - for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { - if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { - embeddingStack.push_back(embeddings); - LOG_INF("Saving layer %d...\n", il); - break; - } - } } // post-layernorm - // TODO - correctly handle last layer with multiple vision feature layers - if (ctx->has_post_norm) { + if (ctx->has_post_norm && max_feature_layer == n_layer) { + LOG_INF("POST NORMALIZING"); embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } - LOG_INF("Stacking multiple vision feature layers\n"); - // Clobber the output embeddings with the saved items in the embedding stack vector - if(embeddingStack.size() > 0) { - embeddings = embeddingStack.at(0); - for(int i=1; i < embeddingStack.size(); i++) { - embeddings = ggml_concat(ctx0, embeddings, embeddingStack.at(i), 0); + // final layer is a vision feature layer + for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { + if (n_layer == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { + LOG_INF("Keeping vision feature layer : %d\n", n_layer); + embedding_stack.push_back(embeddings); + break; } - } + // If feature layers are explicitly set, stack them (if we have multiple) + if(has_feature_layers && embedding_stack.size() > 0) { + LOG_INF("Stacking vision feature layers : %d\n", n_layer); + embeddings = embedding_stack.at(0); + for(unsigned long i=1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack.at(i), 0); + } + } LOG_INF("Layer loop over - trying to llava project...\n"); // llava projector @@ -1488,13 +1497,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - LOG_INF("Grid pinpoints | max %d | actual %d ", MAX_IMAGE_GRID_PINPOINTS, n); - for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && i < n && pinpoints[i] != 0; ++i) { + LOG_INF("Grid pinpoints | max %d | actual %d ", 64, n); + for (int i = 0; i < 64 && i < n && pinpoints[i] != 0; ++i) { LOG_INF(" %d ", i); hparams.image_grid_pinpoints[i] = pinpoints[i]; } LOG_INF("\n"); - if (n < MAX_IMAGE_GRID_PINPOINTS) + if (n < 64) hparams.image_grid_pinpoints[n] = 0; } catch (std::runtime_error & /*e*/) { hparams.image_grid_pinpoints[0]=0; @@ -1518,12 +1527,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // HACK - need to set a good invalid number here; or maybe not, I guess it could just // be that it's not set in GGUF, we read all numbers as valid, and from this point on, // -1 is the sad one - for (int i = 0; i < MAX_VISION_FEATURE_LAYERS && i < n && vision_feature_layer[i] != 0; ++i) { + for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { hparams.vision_feature_layer[i] = vision_feature_layer[i]; LOG_INF("feature layer %d - %d | ", i, vision_feature_layer[i]); } - if (n < MAX_IMAGE_GRID_PINPOINTS) - hparams.image_grid_pinpoints[n] = -1; + if (n < 4) + hparams.vision_feature_layer[n] = -1; } catch (std::runtime_error & /*e*/) { LOG_INF("VISION FEATURE LAYER RETRIEVAL FAILED"); hparams.vision_feature_layer[0] = -1; @@ -1566,7 +1575,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_grid_pinpoints: "); - for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && (hparams.image_grid_pinpoints[i] != 0); ++i) { + for (int i = 0; i < 64 && (hparams.image_grid_pinpoints[i] != 0); ++i) { LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } LOG_INF("\n"); @@ -2325,7 +2334,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli if (params.image_grid_pinpoints[0] != 0) { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector> possible_resolutions; - for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && params.image_grid_pinpoints[i] != 0; i+=2) { + for (int i = 0; i < 64 && params.image_grid_pinpoints[i] != 0; i+=2) { possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); } std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index ab70a55d6..414eb2838 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -227,17 +227,37 @@ if has_text_encoder: fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) fout.add_token_list(tokens) -if has_vision_encoder: - ## FIXME Need to pull this out of the overall model config, not just the top one? - # TODO or document that vision_feature_layer can be set here, but it's usually in the - # llava config and not the vision config itself; - # Handle vision feature layers in transformers, where features may be taken - # from layers that are not the last. NOTE - these values can be unsigned... + + +def get_unsigned_vision_feature_layers(v_hparams): + """ + Determine the vision feature layer(s) for the llava model, which are indices into the + hidden states of the visual encoder. Note that the hidden states array generally takes the + form: + + [, , ... ] + + so positive feature indices should be offset as n+1 to get the output of encoder block n. + We convert all vision feature layers to unsigned ints so that -1 can be used in the model + as an unset value. If no vision feature layer is found, we leave it unset. + """ + num_hidden_layers = v_hparams["num_hidden_layers"] + to_uint = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1 + feature_layers_key = None + # Key used for llava models in transformers if "vision_feature_layer" in config: - feature_layers = v_hparams["vision_feature_layer"] + feature_layers_key = "vision_feature_layer" + # Key used for llava models in the original format + elif "mm_vision_select_layer" in config: + feature_layers_key = "mm_vision_select_layer" + if feature_layers_key is not None: + feature_layers = config[feature_layers_key] if isinstance(feature_layers, int): feature_layers = [feature_layers] - fout.add_array("clip.vision.feature_layer", feature_layers) + return [to_uint(feature_layer) for feature_layer in feature_layers] + +if has_vision_encoder: + feature_layers = get_unsigned_vision_feature_layers(v_hparams) # Siglip does not have a visual projector; set projection dim to 0 if args.clip_model_is_siglip: @@ -245,7 +265,7 @@ if has_vision_encoder: else: visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) - # vision_model hparams + # set vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) @@ -253,7 +273,7 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] #- 1 if has_llava_projector else v_hparams["num_hidden_layers"] + block_count = v_hparams["num_hidden_layers"] fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ @@ -305,7 +325,8 @@ if has_vision_encoder: fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"]) if "mm_projector_type" in v_hparams: fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"]) - + if feature_layers: + fout.add_array("clip.vision.feature_layer", feature_layers) if processor is not None: image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue] From 2327897175f1738b50fd490c0c02e8950243b63b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 07:04:08 -0700 Subject: [PATCH 17/21] Cleanup logs Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 53 +++---------------- .../llava/convert_image_encoder_to_gguf.py | 1 - 2 files changed, 8 insertions(+), 46 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 44a65e40b..9d2922dd3 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -121,6 +121,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" #define KEY_VISION_FEATURE_LAYER "clip.vision.feature_layer" + #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" @@ -170,6 +171,7 @@ static std::string format(const char * fmt, ...) { #define TN_GLM_BOI_W "adapter.boi" #define TN_GLM_EOI_W "adapter.eoi" + enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP_NORM, @@ -429,6 +431,7 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u // // clip layers // + struct clip_hparams { int32_t image_size; int32_t patch_size; @@ -614,7 +617,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; } - LOG_INF("In the graph builder...\n"); const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -666,11 +668,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 /*.mem_buffer =*/ ctx->buf_compute_meta.data(), /*.no_alloc =*/ true, }; - LOG_INF("Making the graph...\n"); struct ggml_context * ctx0 = ggml_init(params); struct ggml_cgraph * gf = ggml_new_graph(ctx0); - LOG_INF("Graph made...\n"); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); ggml_set_name(inp_raw, "inp_raw"); @@ -753,7 +753,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } - LOG_INF("About to iterate over layers...\n"); // Check to see we have 1+ set vision feature layers set; otherwise it's the last layer std::vector embedding_stack; @@ -771,7 +770,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if(max_feature_layer < 0) { max_feature_layer = n_layer; } - LOG_INF("Number of feature layers: %d\n", max_feature_layer); // loop over layers for (int il = 0; il < max_feature_layer; il++) { @@ -781,7 +779,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // NOTE: 0 index here refers to the input to the encoder. for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { - LOG_INF("Keeping vision feature layer: %d\n", il); embedding_stack.push_back(embeddings); break; } @@ -878,7 +875,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // post-layernorm if (ctx->has_post_norm && max_feature_layer == n_layer) { - LOG_INF("POST NORMALIZING"); embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "post_ln"); @@ -888,7 +884,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // final layer is a vision feature layer for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { if (n_layer == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { - LOG_INF("Keeping vision feature layer : %d\n", n_layer); embedding_stack.push_back(embeddings); break; } @@ -896,14 +891,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // If feature layers are explicitly set, stack them (if we have multiple) if(has_feature_layers && embedding_stack.size() > 0) { - LOG_INF("Stacking vision feature layers : %d\n", n_layer); embeddings = embedding_stack.at(0); for(unsigned long i=1; i < embedding_stack.size(); i++) { embeddings = ggml_concat(ctx0, embeddings, embedding_stack.at(i), 0); } } - LOG_INF("Layer loop over - trying to llava project...\n"); // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -920,10 +913,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // llava projector if (ctx->proj_type == PROJECTOR_TYPE_MLP) { - LOG_INF("---- MLP projector ----"); - LOG_INF("proj mlp: mm 0 shape: [%d, %d, %d, %d] | embedding shape: [%d, %d, %d, %d]\n", model.mm_0_w->ne[0], model.mm_0_w->ne[1], model.mm_0_w->ne[2], model.mm_0_w->ne[3], embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - LOG_INF("proj mlp - first mulmat done\n"); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_gelu(ctx0, embeddings); @@ -931,7 +921,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { - LOG_INF("proj mlp norm\n"); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1203,14 +1192,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } - LOG_INF("forward expanding\n"); // build the graph ggml_build_forward_expand(gf, embeddings); - LOG_INF("forward expand done\n"); ggml_free(ctx0); - LOG_INF("freeing it all\n"); return gf; } @@ -1478,7 +1464,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } fin.close(); } - LOG_INF("%s: We are up to the vision model\n", __func__); + // vision model if (new_clip->has_vision_encoder) { // load vision model @@ -1497,31 +1483,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - LOG_INF("Grid pinpoints | max %d | actual %d ", 64, n); for (int i = 0; i < 64 && i < n && pinpoints[i] != 0; ++i) { - LOG_INF(" %d ", i); hparams.image_grid_pinpoints[i] = pinpoints[i]; } - LOG_INF("\n"); if (n < 64) hparams.image_grid_pinpoints[n] = 0; } catch (std::runtime_error & /*e*/) { hparams.image_grid_pinpoints[0]=0; } - // Load the vision feature layer indices; For most models, this will be - // an array of length one with value -1 (i.e., use last layer as visual features), - // but for IBM granite, we have multiple feature layers that get concatenated. - // - // Here, we should standardize all values to uint values so that we can use -1 as unset values. + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be concatenated + // to form the final visual features. try { - LOG_INF("ABOUT TO GET VISION FEATURE LAYER KEYS\n"); int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); - LOG_INF("VISION FEATURE LAYER IDX %d\n", idx); int n = gguf_get_arr_n(ctx, idx); - LOG_INF("GETTING %d VISION FEATURE LAYERS \n", n); - // TODO - fix this - LOG_INF("n_layer in hparams is: %d\n", hparams.n_layer); const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); // HACK - need to set a good invalid number here; or maybe not, I guess it could just @@ -1529,12 +1505,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // -1 is the sad one for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { hparams.vision_feature_layer[i] = vision_feature_layer[i]; - LOG_INF("feature layer %d - %d | ", i, vision_feature_layer[i]); } if (n < 4) hparams.vision_feature_layer[n] = -1; } catch (std::runtime_error & /*e*/) { - LOG_INF("VISION FEATURE LAYER RETRIEVAL FAILED"); hparams.vision_feature_layer[0] = -1; } @@ -1595,8 +1569,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_class_embedding = false; } - LOG_INF("Has class embedding: %d", new_clip->has_class_embedding); - try { vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); @@ -1631,7 +1603,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } catch(const std::exception& /*e*/) { new_clip->has_qwen2vl_merger = false; } - LOG_INF("Loaded up to llava projection"); // LLaVA projection if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { @@ -1769,7 +1740,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->ctx_gguf = ctx; - LOG_INF("About to measure memory and build graphs...\n"); // measure mem requirement and allocate { new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); @@ -1777,7 +1747,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { clip_image_f32_batch batch; batch.size = 1; batch.data = nullptr; - LOG_INF("Entering graph...\n"); ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); @@ -2656,10 +2625,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph - LOG_INF("Doing a batch encode\n"); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); - LOG_INF("did graph alloc\n"); // set inputs const auto & model = ctx->vision_model; @@ -2819,22 +2786,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } } - LOG_INF("about to do backend graph compute\n"); if (ggml_backend_is_cpu(ctx->backend)) { ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); } - LOG_INF("-----\n"); + ggml_backend_graph_compute(ctx->backend, gf); - LOG_INF("did backend graph compute\n"); // the last node is the embedding tensor struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); - LOG_INF("retrieved emb tensor\n"); // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); - LOG_INF("embeddings have been recopied\n"); if (ctx->has_glm_projector) { //eoi diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 414eb2838..df56bf789 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -342,7 +342,6 @@ fout.add_bool("clip.use_gelu", use_gelu) if has_llava_projector: - # model.vision_model.encoder.layers.pop(-1) projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) From 78f765e8a5f4617baea395e8e8268eab330226d3 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 07:06:22 -0700 Subject: [PATCH 18/21] Update comment for vision feature layer init Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 9d2922dd3..ff90ddbf3 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -1495,14 +1495,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // Load the vision feature layer indices if they are explicitly provided; // if multiple vision feature layers are present, the values will be concatenated // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to uints, + // since we use -1 as an unset value here. try { int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); int n = gguf_get_arr_n(ctx, idx); const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); - // HACK - need to set a good invalid number here; or maybe not, I guess it could just - // be that it's not set in GGUF, we read all numbers as valid, and from this point on, - // -1 is the sad one + for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { hparams.vision_feature_layer[i] = vision_feature_layer[i]; } From 17bf6ad3042456d89263f3913c411fc9ef70a5b7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 07:18:56 -0700 Subject: [PATCH 19/21] Update notes for alternative to legacy llm conversion script Signed-off-by: Alex-Brooks --- examples/llava/README-mmgranite.md | 170 ----------------------------- examples/llava/README.md | 17 +++ 2 files changed, 17 insertions(+), 170 deletions(-) delete mode 100644 examples/llava/README-mmgranite.md diff --git a/examples/llava/README-mmgranite.md b/examples/llava/README-mmgranite.md deleted file mode 100644 index b5d59bbff..000000000 --- a/examples/llava/README-mmgranite.md +++ /dev/null @@ -1,170 +0,0 @@ -# Instructions to Convert Multimodal Granite -> GGUF -Disclaimer that this branch is super WIP; eventually this should be combined with the main README in this directory, but separating it for now since we use a different method for converting the LLM. - -First, set the env var `$GRANITE_MODEL` to your vLLM/transformers format multimodal granite model. - -`export GRANITE_MODEL=...` - - -### 1. Running llava surgery v2. -First, we need to run the llava surgery script as shown below: - -`python llava_surgery_v2.py -C -m $GRANITE_MODEL` - -You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory. You can load them directly with pytorch and validate that they are nonempty using the snippet below. - -`ls $GRANITE_MODEL | grep -i llava` - - - -We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty: -```python -import os -import torch - -MODEL_PATH = os.getenv("GRANITE_MODEL") -if not MODEL_PATH: - raise ValueError("env var GRANITE_MODEL is unset!") - -encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip")) -projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector")) - -assert len(encoder_tensors) > 0 -assert len(projector_tensors) > 0 -``` - -If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'mm.0.bias'`, `'mm.0.weight'`, `'mm.2.bias'`, `'mm.2.weight'`, `'model.image_newline'`) in the multimodal `projector_tensors`. - - - -### 2. Creating the Visual Component GGUF -To create the GGUF for the visual components, we need to write a config for the visual encoder. Here is an example Alex wrote for initial testing using the values in the preliminary model; if things are going wrong, there is a good chance it's a misalignment with the config here. - -Note: we refer to this file as `$VISION_CONFIG` later on. -```json -{ - "_name_or_path": "siglip-model", - "architectures": [ - "SiglipVisionModel" - ], - "image_grid_pinpoints": [ - [384,768], - [384,1152], - [384,1536], - [384,1920], - [384,2304], - [384,2688], - [384,3072], - [384,3456], - [384,3840], - [768,384], - [768,768], - [768,1152], - [768,1536], - [768,1920], - [1152,384], - [1152,768], - [1152,1152], - [1536,384], - [1536,768], - [1920,384], - [1920,768], - [2304,384], - [2688,384], - [3072,384], - [3456,384], - [3840,384] - ], - "hidden_size": 1152, - "image_size": 384, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "patch_size": 14, - "transformers_version": "4.45.0.dev0", - "layer_norm_eps": 1e-6, - "hidden_act": "gelu_pytorch_tanh", - "projection_dim": 0 -} -``` - -Create a new directory to hope the visual components, and copy the llava.clip/projector files, as well as the vision config into it. - -``` -ENCODER_PATH=... -VISION_CONFIG=... -mkdir $ENCODER_PATH - -cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin -cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/ -cp $VISION_CONFIG $ENCODER_PATH/config.json -``` - -At which point you should have something like this: -```bash -(venv) alexanderjbrooks@wecm-9-67-137-179 llava % ls $ENCODER_PATH -config.json llava.projector pytorch_model.bin -``` - -Now convert the components to GGUF. -```bash -python convert_image_encoder_to_gguf.py \ - -m $ENCODER_PATH \ - --llava-projector $ENCODER_PATH/llava.projector \ - --output-dir mgranite_siglip \ - --clip-model-is-vision \ - --clip-model-is-siglip -``` - -which will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the abs path of this file as the `$VISUAL_GGUF_PATH.` - - - -### 3. Creating the LLM GGUF. -For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted (Alex will add support to the converter for llava next if possible, but hacking to ignore unused tensors etc with the current instructions currently results in the tokenizer embedding weights not being found). - -To do this, you can do something like the following; we assume you're setting the environment variable `LLM_EXPORT_PATH` to the place to put the exported `transformers` LLM. - -```python -import os -import transformers - -MODEL_PATH = os.getenv("GRANITE_MODEL") -if not MODEL_PATH: - raise ValueError("env var GRANITE_MODEL is unset!") - -LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH") -if not MODEL_PATH: - raise ValueError("env var LLM_EXPORT_PATH is unset!") - -tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH) -# NOTE: need to ignore mismatched sizes for now until Alex's mmgranite PR is merged in transformers -# This also causes actual problems for the load multimodal projector, but since we only -# export the LLM, we don't care for now... -model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True) - -tokenizer.save_pretrained(LLM_EXPORT_PATH) -model.language_model.save_pretrained(LLM_EXPORT_PATH) -``` - -Now you can convert the exported LLM to GGUF with the normal converter. - -```bash -LLM_GGUF_PATH=... - -python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH -``` - - - -### 4. Running the model in llama cpp -Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. Sample usage: -``` -./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \ - --mmproj $VISUAL_GGUF_PATH \ - --image cherry_blossom.jpg \ - -c 16384 \ - -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\\nCan you describe this image?\n<|assistant|>\n" -``` - diff --git a/examples/llava/README.md b/examples/llava/README.md index 012451361..57684b623 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -102,6 +102,23 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096) **note** llava-1.6 greatly benefits from batched prompt processing (defaults work) +**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model. + +```python +import os +import transformers + +model_path = ... +llm_export_path = ... + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) +model = transformers.AutoModelForImageTextToText.from_pretrained(model_path) + +tokenizer.save_pretrained(llm_export_path) +model.language_model.save_pretrained(llm_export_path) +``` + +Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures. ## llava-cli templating and llava-1.6 prompting From 06703820dcfce6598f5051b977a74e77ea0e9d4b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 07:23:07 -0700 Subject: [PATCH 20/21] Fix notes rendering Signed-off-by: Alex-Brooks --- examples/llava/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/llava/README.md b/examples/llava/README.md index 57684b623..0e3c32032 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -101,7 +101,9 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow ``` **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096) + **note** llava-1.6 greatly benefits from batched prompt processing (defaults work) + **note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model. ```python From 262000fa4de0b6e08514b1fdabc54d4daccbdfd2 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 10 Feb 2025 07:52:51 -0700 Subject: [PATCH 21/21] Add v prefix to vision feature layer log Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index ff90ddbf3..5ecb5bacc 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -1553,7 +1553,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } LOG_INF("\n"); - LOG_INF("vision_feature_layer: "); + LOG_INF("v_vision_feature_layer: "); for(int i = 0; i < 4 && (hparams.vision_feature_layer[i] > 0); i++) { LOG_INF("%d ", hparams.vision_feature_layer[i]); }