From 6ccf234031a259bd797137b0ce27efbcd2fc1a34 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 16 Jan 2025 14:54:31 -0700 Subject: [PATCH] Add super wip scripts for multimodal granite gguf Signed-off-by: Alex-Brooks --- examples/llava/clip.cpp | 83 +++++++++++++++++-- .../llava/convert_image_encoder_to_gguf.py | 26 ++++-- examples/llava/llava_surgery_v2.py | 26 ++++-- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- 4 files changed, 119 insertions(+), 20 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 271cf2a2a..136930946 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -120,7 +120,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" - +#define KEY_VISION_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" @@ -444,8 +444,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; + int32_t image_grid_pinpoints[32]; // TODO - check to make sure this is okay for our model... int32_t image_crop_resolution; + int32_t vision_feature_layer[4]; }; struct clip_layer { @@ -615,6 +616,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; } + LOG_INF("In the graph builder...\n"); const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -666,9 +668,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 /*.mem_buffer =*/ ctx->buf_compute_meta.data(), /*.no_alloc =*/ true, }; + LOG_INF("Making the graph...\n"); struct ggml_context * ctx0 = ggml_init(params); struct ggml_cgraph * gf = ggml_new_graph(ctx0); + LOG_INF("Graph made...\n"); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); ggml_set_name(inp_raw, "inp_raw"); @@ -751,13 +755,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } + LOG_INF("About to iterate over layers...\n"); // loop over layers if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { n_layer += 1; } + + // HACK - hold 4 vectors to stack + std::vector embeddingStack; + for (int il = 0; il < n_layer - 1; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + LOG_INF("\tLayer %d...\n", il); + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; @@ -846,7 +857,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - + // Stack embedding feature layers + // HACK - these values might be decremented unncessarily, check hparams layer; maybe this is the int feature layer index? 
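The loop that follows checks, for each encoder layer, whether its index matches one of the up-to-four configured vision_feature_layer entries and, if so, snapshots the current hidden state into embeddingStack; those saved tensors are later joined by the hard-coded four-way ggml_concat after the layer loop. A minimal sketch of how that join could be generalized once the stack is populated, assuming embeddingStack is declared as std::vector<struct ggml_tensor *> (the template argument appears to have been lost in the excerpt above) and that push order is the desired concatenation order:

    if (!embeddingStack.empty()) {
        // fold the saved feature layers into one tensor along dim 0,
        // mirroring the existing pairwise ggml_concat calls
        embeddings = embeddingStack.at(0);
        for (size_t i = 1; i < embeddingStack.size(); i++) {
            embeddings = ggml_concat(ctx0, embeddings, embeddingStack.at(i), 0);
        }
    }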
+ for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) { + if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) { + embeddingStack.push_back(embeddings); + LOG_INF("Saving layer %d...\n", il); + break; + } + } } // post-layernorm @@ -856,6 +875,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } + LOG_INF("Layer loop over - trying to llava project...\n"); + // HACK - super hardcoded tensor concat to make sure things are working. Rewrite me + struct ggml_tensor * embeddingStack1 = ggml_concat(ctx0, embeddingStack.at(0), embeddingStack.at(1), 0); + struct ggml_tensor * embeddingStack2 = ggml_concat(ctx0, embeddingStack.at(2), embeddingStack.at(3), 0); + embeddings = ggml_concat(ctx0, embeddingStack1, embeddingStack2, 0); // llava projector if (ctx->has_llava_projector) { @@ -873,7 +897,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // llava projector if (ctx->proj_type == PROJECTOR_TYPE_MLP) { + LOG_INF("proj mlp: mm 0 shape: [%d, %d, %d, %d] | embedding shape: [%d, %d, %d, %d]\n", model.mm_0_w->ne[0], model.mm_0_w->ne[1], model.mm_0_w->ne[2], model.mm_0_w->ne[3], embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + LOG_INF("proj mlp - first mulmat done\n"); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); embeddings = ggml_gelu(ctx0, embeddings); @@ -881,6 +907,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + LOG_INF("proj mlp norm\n"); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1152,11 +1179,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + LOG_INF("forward expanding\n"); // build the graph ggml_build_forward_expand(gf, embeddings); + LOG_INF("forward expand done\n"); ggml_free(ctx0); + LOG_INF("freeing it all\n"); return gf; } @@ -1424,7 +1454,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } fin.close(); } - + LOG_INF("%s: We are up to the vision model\n", __func__); // vision model if (new_clip->has_vision_encoder) { // load vision model @@ -1452,6 +1482,33 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { hparams.image_grid_pinpoints[0]=0; } + // Load the vision feature layer indices; For most models, this will be + // an array of length one with value -1 (i.e., use last layer as visual features), + // but for IBM granite, we have multiple feature layers that get concatenated. + // + // Here, we should standardize all values to uint values so that we can use -1 as unset values. 
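In the commented-out loader just below, the fallback terminator is written to hparams.image_grid_pinpoints[n]; that looks like a copy-paste slip and presumably should terminate vision_feature_layer instead. A corrected sketch of the same read, assuming the converter emits KEY_VISION_FEATURE_LAYER as an int32 array:

    try {
        int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER);
        int n = gguf_get_arr_n(ctx, idx);
        const int32_t * vision_feature_layer = (const int32_t *) gguf_get_arr_data(ctx, idx);
        for (int i = 0; i < 4 && i < n; ++i) {
            hparams.vision_feature_layer[i] = vision_feature_layer[i];
        }
        // mark any remaining slots as unset
        if (n < 4) {
            hparams.vision_feature_layer[n] = -1;
        }
    } catch (std::runtime_error & /*e*/) {
        // key absent: fall back to taking the final layer output
        hparams.vision_feature_layer[0] = -1;
    }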
+ // try { + // int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER); + // int n = gguf_get_arr_n(ctx, idx); + // const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + // // HACK - need to set a good invalid number here; or maybe not, I guess it could just + // // be that it's not set in GGUF, we read all numbers as valid, and from this point on, + // // -1 is the sad one + // for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) { + // hparams.vision_feature_layer[i] = vision_feature_layer[i]; + // } + // if (n < 4) + // hparams.image_grid_pinpoints[n] = -1; + // } catch (std::runtime_error & /*e*/) { + // // -1 -> taking the final layer output + // hparams.vision_feature_layer[0] = -1; + // } + // HACK for testing without GGUF hparams for now + hparams.vision_feature_layer[0] = 3; + hparams.vision_feature_layer[1] = 7; + hparams.vision_feature_layer[2] = 15; + hparams.vision_feature_layer[3] = 24; // TODO This is wrong and should be 26, but the converter seems to be chopping layers off; investigate + try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); @@ -1493,6 +1550,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } LOG_INF("\n"); + LOG_INF("vision_feature_layer: "); + for(int i = 0; i < 4 && (hparams.vision_feature_layer[i] > 0); i++) { + LOG_INF("%d ", hparams.vision_feature_layer[i]); + } + LOG_INF("\n"); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); } @@ -1504,6 +1566,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_class_embedding = false; } + LOG_INF("Has class embedding: %d", new_clip->has_class_embedding); + try { vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); @@ -1538,6 +1602,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } catch(const std::exception& /*e*/) { new_clip->has_qwen2vl_merger = false; } + LOG_INF("Loaded up to llava projection"); // LLaVA projection if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { @@ -1675,6 +1740,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->ctx_gguf = ctx; + LOG_INF("About to measure memory and build graphs...\n"); // measure mem requirement and allocate { new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); @@ -1682,6 +1748,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { clip_image_f32_batch batch; batch.size = 1; batch.data = nullptr; + LOG_INF("Entering graph...\n"); ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); @@ -2560,8 +2627,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph + LOG_INF("Doing a batch encode\n"); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + LOG_INF("did graph alloc\n"); // set inputs const auto & model = ctx->vision_model; @@ -2721,18 +2790,22 @@ bool 
clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } } + LOG_INF("about to do backend graph compute\n"); if (ggml_backend_is_cpu(ctx->backend)) { ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); } - + LOG_INF("-----\n"); ggml_backend_graph_compute(ctx->backend, gf); + LOG_INF("did backend graph compute\n"); // the last node is the embedding tensor struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); + LOG_INF("retrieved emb tensor\n"); // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + LOG_INF("embeddings have been recopied\n"); if (ctx->has_glm_projector) { //eoi diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 4fa1d6cea..954b5b442 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -6,7 +6,7 @@ import re import torch import numpy as np from gguf import * -from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel +from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipModel, SiglipProcessor, SiglipVisionModel TEXT = "clip.text" VISION = "clip.vision" @@ -85,6 +85,8 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False, help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--clip-model-is-siglip", action="store_true", required=False, + help="the visual encoder is Siglip.") ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. 
Default is the original model directory", default=None) @@ -109,7 +111,7 @@ if args.use_f32: # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: +if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip or args.clip_model_is_siglip: vocab = None tokens = None else: @@ -137,7 +139,11 @@ ftype = 1 if args.use_f32: ftype = 0 -if args.clip_model_is_vision or args.clip_model_is_openclip: +# HACK - not sure if we need the vision model of the model + processor; check the difference +if args.clip_model_is_vision or args.clip_model_is_siglip: + model = SiglipVisionModel.from_pretrained(dir_model) + processor = None +elif args.clip_model_is_vision or args.clip_model_is_openclip: model = CLIPVisionModel.from_pretrained(dir_model) processor = None else: @@ -187,26 +193,34 @@ else: if has_text_encoder: assert t_hparams is not None assert tokens is not None + if args.clip_model_is_siglip: + text_projection_dim = 0 + else: + text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"]) # text_model hparams fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) - fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.text.projection_dim", text_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) fout.add_token_list(tokens) if has_vision_encoder: + if args.clip_model_is_siglip: + visual_projection_dim = 0 + else: + visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) # vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) - fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] # Why is this decremented? Should be 27... 
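A likely answer to the question above: for standard LLaVA models the converter deliberately records one block fewer, and appears to drop the last encoder block's tensors as well, because the projector consumes the penultimate layer's hidden states; that would also explain the "converter seems to be chopping layers off" note in clip.cpp. For Granite's multi-feature-layer setup, which needs the later layers, this decrement presumably has to become conditional rather than applied for every llava projector.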
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py index 2d5b32fe6..5119c9ccc 100644 --- a/examples/llava/llava_surgery_v2.py +++ b/examples/llava/llava_surgery_v2.py @@ -40,7 +40,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): # file_type = 'pytorch' model_path = os.path.dirname(checkpoint_path) print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))] + clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))] if len(clip_tensors) > 0: print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") @@ -85,10 +85,10 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): return newline_checkpoint_path, projector_checkpoint_path def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") for k in checkpoint.keys()) + return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys()) def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) + return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys()) # Command-line interface setup @@ -123,14 +123,14 @@ first_checkpoint = None if newline_checkpoint_path is not None: print(f"Taking newline from {newline_checkpoint_path}") first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] + first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")] # Load the checkpoint mm_tensors = [] last_checkpoint = None if projector_checkpoint_path is not None: last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] + mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")] if len(mm_tensors) == 0: if last_checkpoint is not None: @@ -146,14 +146,24 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.") projector = {} for name in mm_tensors: assert last_checkpoint is not None - projector[name] = last_checkpoint[name].float() + # HACK - this should probably be in the second script... + new_name = name + if new_name.startswith("multi_modal_projector.linear_1"): + new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0") + elif new_name.startswith("multi_modal_projector.linear_2"): + new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2") + projector[new_name] = last_checkpoint[name].float() for name in first_mm_tensors: assert first_checkpoint is not None - projector[name] = first_checkpoint[name].float() + # HACK - this should probably be in the second script too... 
+ new_name = name + if new_name == "image_newline": + new_name = "model.image_newline" + projector[new_name] = first_checkpoint[name].float() if len(projector) > 0: save_model(projector, f"{args.model}/llava.projector", 'pytorch') print("Done!") -print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index e809f05d2..a86bdb939 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -8515,7 +8515,9 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - GGML_ASSERT(i01 >= 0 && i01 < ne01); + // Copying this out for a bit while investigating due to issues like: + // https://github.com/ggerganov/llama.cpp/issues/10157 + // GGML_ASSERT(i01 >= 0 && i01 < ne01); ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
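Rather than deleting the bounds check in ggml_compute_forward_get_rows_f32 outright, a lighter-touch stopgap while the linked issue is investigated could be to log the offending index so the failure mode stays visible; a minimal sketch (assuming stdio is available in this translation unit):

    if (i01 < 0 || i01 >= ne01) {
        // report the bad row index instead of silently reading out of range
        fprintf(stderr, "get_rows_f32: row index %lld outside [0, %lld)\n",
                (long long) i01, (long long) ne01);
    }
    // GGML_ASSERT(i01 >= 0 && i01 < ne01);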