Standardize vision feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex-Brooks 2025-02-10 06:43:49 -07:00
parent 3a191f8edb
commit 188a068a04
2 changed files with 85 additions and 55 deletions

View file

@ -227,17 +227,37 @@ if has_text_encoder:
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
if has_vision_encoder:
## FIXME Need to pull this out of the overall model config, not just the top one?
# TODO or document that vision_feature_layer can be set here, but it's usually in the
# llava config and not the vision config itself;
# Handle vision feature layers in transformers, where features may be taken
# from layers that are not the last. NOTE - these values can be unsigned...
def get_unsigned_vision_feature_layers(v_hparams):
"""
Determine the vision feature layer(s) for the llava model, which are indices into the
hidden states of the visual encoder. Note that the hidden states array generally takes the
form:
[<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers>]
so positive feature indices should be offset as n+1 to get the output of encoder block n.
We convert all vision feature layers to unsigned ints so that -1 can be used in the model
as an unset value. If no vision feature layer is found, we leave it unset.
"""
num_hidden_layers = v_hparams["num_hidden_layers"]
to_uint = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
feature_layers_key = None
# Key used for llava models in transformers
if "vision_feature_layer" in config:
feature_layers = v_hparams["vision_feature_layer"]
feature_layers_key = "vision_feature_layer"
# Key used for llava models in the original format
elif "mm_vision_select_layer" in config:
feature_layers_key = "mm_vision_select_layer"
if feature_layers_key is not None:
feature_layers = config[feature_layers_key]
if isinstance(feature_layers, int):
feature_layers = [feature_layers]
fout.add_array("clip.vision.feature_layer", feature_layers)
return [to_uint(feature_layer) for feature_layer in feature_layers]
if has_vision_encoder:
feature_layers = get_unsigned_vision_feature_layers(v_hparams)
# Siglip does not have a visual projector; set projection dim to 0
if args.clip_model_is_siglip:
@ -245,7 +265,7 @@ if has_vision_encoder:
else:
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
# vision_model hparams
# set vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
@ -253,7 +273,7 @@ if has_vision_encoder:
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] #- 1 if has_llava_projector else v_hparams["num_hidden_layers"]
block_count = v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [
@ -305,7 +325,8 @@ if has_vision_encoder:
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
if feature_layers:
fout.add_array("clip.vision.feature_layer", feature_layers)
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]