Add super WIP scripts for multimodal Granite GGUF

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

commit 6ccf234031 (parent d774ab3acc)
4 changed files with 119 additions and 20 deletions
@@ -120,7 +120,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_IMAGE_MEAN "clip.vision.image_mean"
 #define KEY_IMAGE_STD "clip.vision.image_std"
 #define KEY_PROJ_TYPE "clip.projector_type"
-
+#define KEY_VISION_FEATURE_LAYER "clip.vision.feature_layer"
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
@@ -444,8 +444,9 @@ struct clip_hparams {
 
     char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
 
-    int32_t image_grid_pinpoints[32];
+    int32_t image_grid_pinpoints[32]; // TODO - check to make sure this is okay for our model...
     int32_t image_crop_resolution;
+    int32_t vision_feature_layer[4];
 };
 
 struct clip_layer {
@@ -615,6 +616,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         LOG_ERR("This gguf file seems to have no vision encoder\n");
         return nullptr;
     }
+    LOG_INF("In the graph builder...\n");
 
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
@@ -666,9 +668,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
         /*.no_alloc =*/ true,
     };
+    LOG_INF("Making the graph...\n");
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+    LOG_INF("Graph made...\n");
 
     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
     ggml_set_name(inp_raw, "inp_raw");
@@ -751,13 +755,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
     }
+    LOG_INF("About to iterate over layers...\n");
 
     // loop over layers
     if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
         n_layer += 1;
     }
+
+    // HACK - hold 4 vectors to stack
+    std::vector<struct ggml_tensor *> embeddingStack;
+
     for (int il = 0; il < n_layer - 1; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+        LOG_INF("\tLayer %d...\n", il);
+
 
         //const size_t nb_q_w = model.layers[il].q_w->nb[0];
 
@@ -846,7 +857,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         cur = ggml_add(ctx0, embeddings, cur);
 
         embeddings = cur;
-
+        // Stack embedding feature layers
+        // HACK - these values might be decremented unnecessarily, check hparams layer; maybe this is the int feature layer index?
+        for (int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
+            if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
+                embeddingStack.push_back(embeddings);
+                LOG_INF("Saving layer %d...\n", il);
+                break;
+            }
+        }
     }
 
     // post-layernorm
@@ -856,6 +875,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
     }
+    LOG_INF("Layer loop over - trying to llava project...\n");
+    // HACK - super hardcoded tensor concat to make sure things are working. Rewrite me
+    struct ggml_tensor * embeddingStack1 = ggml_concat(ctx0, embeddingStack.at(0), embeddingStack.at(1), 0);
+    struct ggml_tensor * embeddingStack2 = ggml_concat(ctx0, embeddingStack.at(2), embeddingStack.at(3), 0);
+    embeddings = ggml_concat(ctx0, embeddingStack1, embeddingStack2, 0);
 
     // llava projector
     if (ctx->has_llava_projector) {
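Note: the concat above is explicitly flagged "Rewrite me". One way to generalize it is to fold however many feature layers were actually collected instead of assuming exactly four. A minimal sketch under that assumption (not part of this commit; it reuses `embeddingStack` and `ctx0` from the surrounding code and assumes the stack is non-empty):

```cpp
// Sketch only: concatenate every saved feature layer along dim 0,
// matching the layout of the hardcoded 4-way stack above.
struct ggml_tensor * stacked = embeddingStack.at(0);
for (size_t i = 1; i < embeddingStack.size(); i++) {
    stacked = ggml_concat(ctx0, stacked, embeddingStack.at(i), 0);
}
embeddings = stacked;
```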
@@ -873,7 +897,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
         // llava projector
         if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+            LOG_INF("proj mlp: mm 0 shape: [%d, %d, %d, %d] | embedding shape: [%d, %d, %d, %d]\n", model.mm_0_w->ne[0], model.mm_0_w->ne[1], model.mm_0_w->ne[2], model.mm_0_w->ne[3], embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
             embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            LOG_INF("proj mlp - first mulmat done\n");
             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
             embeddings = ggml_gelu(ctx0, embeddings);
@@ -881,6 +907,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
         }
         else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            LOG_INF("proj mlp norm\n");
             embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
             // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@@ -1152,11 +1179,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
     }
+    LOG_INF("forward expanding\n");
 
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
+    LOG_INF("forward expand done\n");
 
     ggml_free(ctx0);
+    LOG_INF("freeing it all\n");
 
     return gf;
 }
@@ -1424,7 +1454,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
         fin.close();
     }
-
+    LOG_INF("%s: We are up to the vision model\n", __func__);
     // vision model
     if (new_clip->has_vision_encoder) {
         // load vision model
@@ -1452,6 +1482,33 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             hparams.image_grid_pinpoints[0]=0;
         }
 
+        // Load the vision feature layer indices; for most models, this will be
+        // an array of length one with value -1 (i.e., use last layer as visual features),
+        // but for IBM granite, we have multiple feature layers that get concatenated.
+        //
+        // Here, we should standardize all values to int values so that we can use -1 as unset values.
+        // try {
+        //     int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER);
+        //     int n = gguf_get_arr_n(ctx, idx);
+        //     const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);
+        //     // HACK - need to set a good invalid number here; or maybe not, I guess it could just
+        //     // be that it's not set in GGUF, we read all numbers as valid, and from this point on,
+        //     // -1 is the sad one
+        //     for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) {
+        //         hparams.vision_feature_layer[i] = vision_feature_layer[i];
+        //     }
+        //     if (n < 4)
+        //         hparams.image_grid_pinpoints[n] = -1;
+        // } catch (std::runtime_error & /*e*/) {
+        //     // -1 -> taking the final layer output
+        //     hparams.vision_feature_layer[0] = -1;
+        // }
+        // HACK for testing without GGUF hparams for now
+        hparams.vision_feature_layer[0] = 3;
+        hparams.vision_feature_layer[1] = 7;
+        hparams.vision_feature_layer[2] = 15;
+        hparams.vision_feature_layer[3] = 24; // TODO This is wrong and should be 26, but the converter seems to be chopping layers off; investigate
+
         try {
             int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
             strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));
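Note: once the converter actually writes `clip.vision.feature_layer`, the commented-out draft above could be revived along these lines. This is only a sketch, not what the commit ships: it uses the same helpers the draft references (`get_key_idx`, `gguf_get_arr_n`, `gguf_get_arr_data`), pre-fills every slot with -1 as the unset sentinel (which also sidesteps the terminator question in the draft, whose fallback write currently targets `image_grid_pinpoints` and looks like a copy-paste slip), and keeps -1 meaning "use the final layer output".

```cpp
// Sketch, not committed code: read up to 4 feature-layer indices from GGUF,
// defaulting every slot to -1 ("unset" / take the final layer output).
for (int i = 0; i < 4; i++) {
    hparams.vision_feature_layer[i] = -1;
}
try {
    const int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER);
    const int n   = (int) gguf_get_arr_n(ctx, idx);
    const int32_t * vision_feature_layer = (const int32_t *) gguf_get_arr_data(ctx, idx);
    for (int i = 0; i < 4 && i < n; i++) {
        hparams.vision_feature_layer[i] = vision_feature_layer[i];
    }
} catch (std::runtime_error & /*e*/) {
    // key absent: keep the -1 defaults
}
```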
@@ -1493,6 +1550,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                 LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
             }
             LOG_INF("\n");
+            LOG_INF("vision_feature_layer: ");
+            for (int i = 0; i < 4 && (hparams.vision_feature_layer[i] > 0); i++) {
+                LOG_INF("%d ", hparams.vision_feature_layer[i]);
+            }
+            LOG_INF("\n");
             LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
 
         }
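Note: the verbose dump just added stops at the first entry that is not strictly positive, so an index of 0 would never print, and the check does not line up with the -1 "unset" convention discussed above. If that convention lands, the dump could instead look like this (sketch only):

```cpp
// Sketch: terminate on the -1 sentinel rather than on any non-positive value,
// so layer index 0 stays printable if it is ever a valid feature layer.
LOG_INF("vision_feature_layer: ");
for (int i = 0; i < 4 && hparams.vision_feature_layer[i] != -1; i++) {
    LOG_INF("%d ", hparams.vision_feature_layer[i]);
}
LOG_INF("\n");
```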
@@ -1504,6 +1566,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             new_clip->has_class_embedding = false;
         }
 
+        LOG_INF("Has class embedding: %d", new_clip->has_class_embedding);
+
         try {
             vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
             vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
@@ -1538,6 +1602,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         } catch(const std::exception& /*e*/) {
             new_clip->has_qwen2vl_merger = false;
         }
+        LOG_INF("Loaded up to llava projection");
 
         // LLaVA projection
         if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
@@ -1675,6 +1740,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
 
     new_clip->ctx_gguf = ctx;
 
+    LOG_INF("About to measure memory and build graphs...\n");
     // measure mem requirement and allocate
     {
         new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
@@ -1682,6 +1748,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         clip_image_f32_batch batch;
         batch.size = 1;
         batch.data = nullptr;
+        LOG_INF("Entering graph...\n");
         ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
         ggml_gallocr_reserve(new_clip->compute_alloc, gf);
         size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
@@ -2560,8 +2627,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // build the inference graph
+    LOG_INF("Doing a batch encode\n");
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
     ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+    LOG_INF("did graph alloc\n");
 
     // set inputs
     const auto & model = ctx->vision_model;
@@ -2721,18 +2790,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
         }
     }
+    LOG_INF("about to do backend graph compute\n");
 
     if (ggml_backend_is_cpu(ctx->backend)) {
         ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
     }
-
+    LOG_INF("-----\n");
     ggml_backend_graph_compute(ctx->backend, gf);
+    LOG_INF("did backend graph compute\n");
 
     // the last node is the embedding tensor
     struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+    LOG_INF("retrieved emb tensor\n");
 
     // copy the embeddings to the location passed by the user
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+    LOG_INF("embeddings have been recopied\n");
 
     if (ctx->has_glm_projector) {
         //eoi
@@ -6,7 +6,7 @@ import re
 import torch
 import numpy as np
 from gguf import *
-from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipModel, SiglipProcessor, SiglipVisionModel
 
 TEXT = "clip.text"
 VISION = "clip.vision"
@@ -85,6 +85,8 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
                 help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
 ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                 help="The clip model is from openclip (for ViT-SO400M type))")
+ap.add_argument("--clip-model-is-siglip", action="store_true", required=False,
+                help="the visual encoder is Siglip.")
 ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
 ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
@@ -109,7 +111,7 @@ if args.use_f32:
 # output in the same directory as the model if output_dir is None
 dir_model = args.model_dir
 
-if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
+if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip or args.clip_model_is_siglip:
     vocab = None
     tokens = None
 else:
@@ -137,7 +139,11 @@ ftype = 1
 if args.use_f32:
     ftype = 0
 
-if args.clip_model_is_vision or args.clip_model_is_openclip:
+# HACK - not sure if we need the vision model of the model + processor; check the difference
+if args.clip_model_is_vision or args.clip_model_is_siglip:
+    model = SiglipVisionModel.from_pretrained(dir_model)
+    processor = None
+elif args.clip_model_is_vision or args.clip_model_is_openclip:
     model = CLIPVisionModel.from_pretrained(dir_model)
     processor = None
 else:
@@ -187,26 +193,34 @@ else:
 if has_text_encoder:
     assert t_hparams is not None
     assert tokens is not None
+    if args.clip_model_is_siglip:
+        text_projection_dim = 0
+    else:
+        text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
     # text_model hparams
     fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
     fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
-    fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
+    fout.add_uint32("clip.text.projection_dim", text_projection_dim)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
     fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
     fout.add_token_list(tokens)
 
 if has_vision_encoder:
+    if args.clip_model_is_siglip:
+        visual_projection_dim = 0
+    else:
+        visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
     # vision_model hparams
     fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
     fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
     fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
-    fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
+    fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
-    block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
+    block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] # Why is this decremented? Should be 27...
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
     # /**
     # "image_grid_pinpoints": [
@@ -40,7 +40,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
     # file_type = 'pytorch'
     model_path = os.path.dirname(checkpoint_path)
     print(f"Searching for vision tower tensors in {checkpoint_path}")
-    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
+    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit.") or k.startswith("vision_tower"))]
 
     if len(clip_tensors) > 0:
         print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -85,10 +85,10 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
     return newline_checkpoint_path, projector_checkpoint_path
 
 def newline_criteria(checkpoint):
-    return any(k.startswith("model.image_newline") for k in checkpoint.keys())
+    return any(k.startswith("model.image_newline") or k.startswith("image_newline") for k in checkpoint.keys())
 
 def proj_criteria(checkpoint):
-    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
+    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector") for k in checkpoint.keys())
 
 
 # Command-line interface setup
@@ -123,14 +123,14 @@ first_checkpoint = None
 if newline_checkpoint_path is not None:
     print(f"Taking newline from {newline_checkpoint_path}")
     first_checkpoint, file_type = load_model(newline_checkpoint_path)
-    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline") or k.startswith("image_newline")]
 
 # Load the checkpoint
 mm_tensors = []
 last_checkpoint = None
 if projector_checkpoint_path is not None:
     last_checkpoint, file_type = load_model(projector_checkpoint_path)
-    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
+    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.") or k.startswith("multi_modal_projector")]
 
 if len(mm_tensors) == 0:
     if last_checkpoint is not None:
@@ -146,14 +146,24 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 projector = {}
 for name in mm_tensors:
     assert last_checkpoint is not None
-    projector[name] = last_checkpoint[name].float()
+    # HACK - this should probably be in the second script...
+    new_name = name
+    if new_name.startswith("multi_modal_projector.linear_1"):
+        new_name = new_name.replace("multi_modal_projector.linear_1", "mm.0")
+    elif new_name.startswith("multi_modal_projector.linear_2"):
+        new_name = new_name.replace("multi_modal_projector.linear_2", "mm.2")
+    projector[new_name] = last_checkpoint[name].float()
 for name in first_mm_tensors:
     assert first_checkpoint is not None
-    projector[name] = first_checkpoint[name].float()
+    # HACK - this should probably be in the second script too...
+    new_name = name
+    if new_name == "image_newline":
+        new_name = "model.image_newline"
+    projector[new_name] = first_checkpoint[name].float()
 
 if len(projector) > 0:
     save_model(projector, f"{args.model}/llava.projector", 'pytorch')
 
 print("Done!")
-print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
 print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
@@ -8515,7 +8515,9 @@ static void ggml_compute_forward_get_rows_f32(
             const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
             const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+            // Commenting this out for a bit while investigating due to issues like:
+            // https://github.com/ggerganov/llama.cpp/issues/10157
+            // GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
             ggml_vec_cpy_f32(nc,
                     (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
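Note: commenting the assert out avoids the crash in the linked issue but also hides genuinely invalid indices. If the goal is just to keep running while the root cause is investigated, skipping bad rows is a possible interim guard (sketch only, not part of this commit):

```cpp
// Sketch: tolerate out-of-range row indices without dropping the check entirely.
if (i01 < 0 || i01 >= ne01) {
    continue; // skip rows that cannot be copied safely
}
```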