Standardize vision feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex-Brooks 2025-02-10 06:43:49 -07:00
parent 3a191f8edb
commit 188a068a04
2 changed files with 85 additions and 55 deletions

@@ -170,10 +170,6 @@ static std::string format(const char * fmt, ...) {
#define TN_GLM_BOI_W "adapter.boi"
#define TN_GLM_EOI_W "adapter.eoi"
// Other constants for max array lengths of hparams etc
#define MAX_IMAGE_GRID_PINPOINTS 64
#define MAX_VISION_FEATURE_LAYERS 4
enum projector_type {
PROJECTOR_TYPE_MLP,
PROJECTOR_TYPE_MLP_NORM,
@@ -446,9 +442,9 @@ struct clip_hparams {
char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
int32_t image_grid_pinpoints[MAX_IMAGE_GRID_PINPOINTS];
int32_t image_grid_pinpoints[64];
int32_t image_crop_resolution;
int32_t vision_feature_layer[MAX_VISION_FEATURE_LAYERS];
int32_t vision_feature_layer[4];
};
struct clip_layer {
@@ -759,22 +755,37 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
LOG_INF("About to iterate over layers...\n");
// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
// Check whether we have one or more vision feature layers set; otherwise fall back to the last layer
std::vector<struct ggml_tensor *> embedding_stack;
bool has_feature_layers = ctx->vision_model.hparams.vision_feature_layer[0] > 0;
// Determine how many encoder layers we need to process; if we have explicit vision feature
// layers, only process what we need, otherwise process all of the visual encoder layers.
int max_feature_layer = -1;
if (has_feature_layers) {
for (int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx] > max_feature_layer) {
max_feature_layer = ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx];
}
}
}
if (max_feature_layer < 0) {
max_feature_layer = n_layer;
}
LOG_INF("Number of feature layers: %d\n", max_feature_layer);
// HACK - hold up to 4 layer outputs to stack
std::vector<struct ggml_tensor *> embeddingStack;
// TODO - n_layer was previously n_layer - 1, probably so that -2 could be used as the
// feature layer. Using that as a default is probably a good idea, but otherwise we should
// infer how deep into the encoder we actually have to go based on the hparams for the
// vision feature layer...
for (int il = 0; il < n_layer; il++) {
// loop over layers
for (int il = 0; il < max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
LOG_INF("\tLayer %d...\n", il);
// If this is an embedding feature layer, save the output.
// NOTE: index 0 here refers to the input to the encoder.
for (int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
LOG_INF("Keeping vision feature layer: %d\n", il);
embedding_stack.push_back(embeddings);
break;
}
}
//const size_t nb_q_w = model.layers[il].q_w->nb[0];
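For reference, the layer-count selection above reduces to a small pure function. The sketch below is illustrative only (the helper name is hypothetical); it assumes the same 4-entry hparams array, where values > 0 are set layers filled from index 0 and anything else is unset:

    // Hypothetical helper: choose the deepest encoder layer the graph must run.
    // If any vision feature layer is set (> 0), only run up to the largest
    // requested hidden-state index; otherwise run all n_layer encoder layers.
    static int get_max_feature_layer(const int32_t feature_layers[4], int n_layer) {
        int max_feature_layer = -1;
        for (int i = 0; i < 4; i++) {
            if (feature_layers[i] > max_feature_layer) {
                max_feature_layer = feature_layers[i];
            }
        }
        return max_feature_layer > 0 ? max_feature_layer : n_layer;
    }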
@@ -863,36 +874,34 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
// Stack embedding feature layers
// HACK - these values might be decremented unnecessarily; check the hparams layer - maybe this is the feature layer index as an int?
for (int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
embeddingStack.push_back(embeddings);
LOG_INF("Saving layer %d...\n", il);
break;
}
}
}
// post-layernorm
// TODO - correctly handle last layer with multiple vision feature layers
if (ctx->has_post_norm) {
if (ctx->has_post_norm && max_feature_layer == n_layer) {
LOG_INF("POST NORMALIZING");
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
LOG_INF("Stacking multiple vision feature layers\n");
// Clobber the output embeddings with the saved items in the embedding stack vector
if(embeddingStack.size() > 0) {
embeddings = embeddingStack.at(0);
for(int i=1; i < embeddingStack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embeddingStack.at(i), 0);
// final layer is a vision feature layer
for (int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (n_layer == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
LOG_INF("Keeping vision feature layer: %d\n", n_layer);
embedding_stack.push_back(embeddings);
break;
}
}
// If feature layers are explicitly set, stack them (if we have multiple)
if (has_feature_layers && embedding_stack.size() > 0) {
LOG_INF("Stacking %d vision feature layers\n", (int) embedding_stack.size());
embeddings = embedding_stack.at(0);
for (size_t i = 1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack.at(i), 0);
}
}
LOG_INF("Layer loop over - trying to llava project...\n");
// llava projector
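To make the stacking step concrete: each saved tensor has shape [n_embd, n_tokens], and ggml_concat along dim 0 grows the feature axis, so k stacked layers yield [k * n_embd, n_tokens]. A minimal sketch, assuming a valid ggml context ctx0 and sizes n_embd and n_tokens:

    // Sketch: concatenating two per-layer embeddings along dim 0 doubles the
    // feature dimension; this is the shape the multimodal projector consumes
    // when two vision feature layers are configured.
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
    struct ggml_tensor * stacked = ggml_concat(ctx0, a, b, 0);
    // stacked->ne[0] == 2 * n_embd, stacked->ne[1] == n_tokens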
@@ -1488,13 +1497,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
LOG_INF("Grid pinpoints | max %d | actual %d ", MAX_IMAGE_GRID_PINPOINTS, n);
for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && i < n && pinpoints[i] != 0; ++i) {
LOG_INF("Grid pinpoints | max %d | actual %d ", 64, n);
for (int i = 0; i < 64 && i < n && pinpoints[i] != 0; ++i) {
LOG_INF(" %d ", i);
hparams.image_grid_pinpoints[i] = pinpoints[i];
}
LOG_INF("\n");
if (n < MAX_IMAGE_GRID_PINPOINTS)
if (n < 64)
hparams.image_grid_pinpoints[n] = 0;
} catch (std::runtime_error & /*e*/) {
hparams.image_grid_pinpoints[0]=0;
@@ -1518,12 +1527,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// HACK - we need to pick a good invalid number here; or maybe not - it could just
// mean the value is not set in the GGUF, in which case we read all numbers as valid,
// and from this point on -1 marks an unset layer
for (int i = 0; i < MAX_VISION_FEATURE_LAYERS && i < n && vision_feature_layer[i] != 0; ++i) {
for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) {
hparams.vision_feature_layer[i] = vision_feature_layer[i];
LOG_INF("feature layer %d - %d | ", i, vision_feature_layer[i]);
}
if (n < MAX_IMAGE_GRID_PINPOINTS)
hparams.image_grid_pinpoints[n] = -1;
if (n < 4)
hparams.vision_feature_layer[n] = -1;
} catch (std::runtime_error & /*e*/) {
LOG_INF("VISION FEATURE LAYER RETRIEVAL FAILED");
hparams.vision_feature_layer[0] = -1;
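With the -1 sentinel convention above, consumers can walk the array without a separate count. A minimal sketch of the read side, assuming entries past the last set layer are -1:

    // Sketch: iterate the configured vision feature layers, stopping at the
    // first -1 (unset) entry.
    for (int i = 0; i < 4 && hparams.vision_feature_layer[i] != -1; i++) {
        LOG_INF("using vision feature layer %d\n", hparams.vision_feature_layer[i]);
    }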
@@ -1566,7 +1575,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: ");
for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && (hparams.image_grid_pinpoints[i] != 0); ++i) {
for (int i = 0; i < 64 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
}
LOG_INF("\n");
@@ -2325,7 +2334,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
if (params.image_grid_pinpoints[0] != 0) {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < MAX_IMAGE_GRID_PINPOINTS && params.image_grid_pinpoints[i] != 0; i+=2) {
for (int i = 0; i < 64 && params.image_grid_pinpoints[i] != 0; i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);

@@ -227,17 +227,37 @@ if has_text_encoder:
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
if has_vision_encoder:
## FIXME Need to pull this out of the overall model config, not just the top one?
# TODO or document that vision_feature_layer can be set here, but it's usually in the
# llava config and not the vision config itself;
# Handle vision feature layers in transformers, where features may be taken
# from layers that are not the last. NOTE - these values can be negative...
def get_unsigned_vision_feature_layers(v_hparams):
"""
Determine the vision feature layer(s) for the llava model, which are indices into the
hidden states of the visual encoder. Note that the hidden states array generally takes the
form:
[<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers - 1>]
i.e., the output of encoder block n sits at hidden state index n + 1.
We convert all vision feature layers to unsigned ints so that -1 can be used in the model
as an unset value. If no vision feature layer is found, we leave it unset.
"""
num_hidden_layers = v_hparams["num_hidden_layers"]
to_uint = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
feature_layers_key = None
# Key used for llava models in transformers
if "vision_feature_layer" in config:
feature_layers = v_hparams["vision_feature_layer"]
feature_layers_key = "vision_feature_layer"
# Key used for llava models in the original format
elif "mm_vision_select_layer" in config:
feature_layers_key = "mm_vision_select_layer"
if feature_layers_key is not None:
feature_layers = config[feature_layers_key]
if isinstance(feature_layers, int):
feature_layers = [feature_layers]
fout.add_array("clip.vision.feature_layer", feature_layers)
return [to_uint(feature_layer) for feature_layer in feature_layers]
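As a cross-check against the C++ loader, the to_uint lambda above corresponds to this mapping (a sketch; the helper name is hypothetical), shown here in C++ with a worked example for num_hidden_layers = 24:

    // Map a transformers-style feature layer index (possibly negative) onto the
    // unsigned hidden-state index stored in GGUF. With num_hidden_layers = 24:
    //   map_layer(-1, 24) -> 24  (final encoder output, hidden_states[24])
    //   map_layer(-2, 24) -> 23  (output of encoder block 22)
    //   map_layer( 3, 24) ->  3  (non-negative indices pass through unchanged)
    static int map_layer(int layer_idx, int num_hidden_layers) {
        return layer_idx >= 0 ? layer_idx : num_hidden_layers + layer_idx + 1;
    }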
if has_vision_encoder:
feature_layers = get_unsigned_vision_feature_layers(v_hparams)
# Siglip does not have a visual projector; set projection dim to 0
if args.clip_model_is_siglip:
@@ -245,7 +265,7 @@ if has_vision_encoder:
else:
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
# vision_model hparams
# set vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
@@ -253,7 +273,7 @@ if has_vision_encoder:
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] #- 1 if has_llava_projector else v_hparams["num_hidden_layers"]
block_count = v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [
@@ -305,7 +325,8 @@ if has_vision_encoder:
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
if feature_layers:
fout.add_array("clip.vision.feature_layer", feature_layers)
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]