Standardize vision feature layers
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
3a191f8edb
commit
188a068a04
2 changed files with 85 additions and 55 deletions
|
@ -227,17 +227,37 @@ if has_text_encoder:
|
|||
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
|
||||
fout.add_token_list(tokens)
|
||||
|
||||
if has_vision_encoder:
|
||||
## FIXME Need to pull this out of the overall model config, not just the top one?
|
||||
# TODO or document that vision_feature_layer can be set here, but it's usually in the
|
||||
# llava config and not the vision config itself;
|
||||
# Handle vision feature layers in transformers, where features may be taken
|
||||
# from layers that are not the last. NOTE - these values can be unsigned...
|
||||
|
||||
|
||||
def get_unsigned_vision_feature_layers(v_hparams, model_config=None):
    """Return the vision feature layer(s) as non-negative hidden-state indices.

    The feature layers are indices into the hidden states of the visual
    encoder. The hidden states array generally takes the form:

        [<emb input>, <output of enc block 0>, ...,
         <output of enc block num_hidden_layers - 1>]

    so index n + 1 holds the output of encoder block n, and the array has
    num_hidden_layers + 1 entries. Negative (Python-style) indices are
    converted to their non-negative equivalent so that -1 can be used in the
    model as an "unset" value.

    Args:
        v_hparams: Vision encoder hyperparameters; must contain
            "num_hidden_layers".
        model_config: Mapping searched for the feature layer setting. When
            omitted, the module-level ``config`` is used.

    Returns:
        A list of non-negative layer indices, or None if neither
        "vision_feature_layer" nor "mm_vision_select_layer" is present, in
        which case the caller should leave the value unset.

    Raises:
        KeyError: If ``v_hparams`` has no "num_hidden_layers" entry.
    """
    if model_config is None:
        # Fall back to the model config loaded at module level by this script.
        model_config = config

    num_hidden_layers = v_hparams["num_hidden_layers"]

    # "vision_feature_layer" is the key used for llava models in transformers;
    # "mm_vision_select_layer" is the key used for the original llava format.
    # The transformers key takes precedence when both are present.
    for key in ("vision_feature_layer", "mm_vision_select_layer"):
        if key in model_config:
            feature_layers = model_config[key]
            break
    else:
        return None

    # A single layer may be given as a bare int; normalize to a list.
    if isinstance(feature_layers, int):
        feature_layers = [feature_layers]

    # This helper only computes the indices; writing them to the output file
    # is left to the caller so the key is written once, with converted values.
    return [
        layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
        for layer_idx in feature_layers
    ]
|
||||
|
||||
if has_vision_encoder:
|
||||
feature_layers = get_unsigned_vision_feature_layers(v_hparams)
|
||||
|
||||
# Siglip does not have a visual projector; set projection dim to 0
|
||||
if args.clip_model_is_siglip:
|
||||
|
@ -245,7 +265,7 @@ if has_vision_encoder:
|
|||
else:
|
||||
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
|
||||
|
||||
# vision_model hparams
|
||||
# set vision_model hparams
|
||||
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
|
||||
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
|
||||
|
@ -253,7 +273,7 @@ if has_vision_encoder:
|
|||
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
|
||||
block_count = v_hparams["num_hidden_layers"] #- 1 if has_llava_projector else v_hparams["num_hidden_layers"]
|
||||
block_count = v_hparams["num_hidden_layers"]
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
|
||||
# /**
|
||||
# "image_grid_pinpoints": [
|
||||
|
@ -305,7 +325,8 @@ if has_vision_encoder:
|
|||
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
|
||||
if "mm_projector_type" in v_hparams:
|
||||
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
|
||||
|
||||
if feature_layers:
|
||||
fout.add_array("clip.vision.feature_layer", feature_layers)
|
||||
|
||||
if processor is not None:
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue