push partial code here
commit fde72a9f86 (parent bd2d24aa0d)
2 changed files with 513 additions and 3 deletions
examples/xgenmm/model_breakdown.ipynb (new file, +492 lines; diff suppressed because one or more lines are too long)
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     elif args.xgenmm_projector is not None:
         fname_middle = "mmproj-"
         has_text_encoder = False
-        has_xgenmm_projector = False
+        has_xgenmm_projector = True
     elif args.vision_only:
         fname_middle = "vision-"
         has_text_encoder = False
@@ -189,9 +189,13 @@ if __name__ == "__main__":
     fout.add_uint32("clip.vision.projection_dim", 0)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vision_config["num_attention_heads"])
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), vision_config["layer_norm_eps"])
-    block_count = vision_config["num_hidden_layers"] - 1 if has_xgenmm_projector else vision_config["num_hidden_layers"]
+    # TODO: check this as it might cause bugs
+    # original llava implementation:
+    # block_count = vision_config["num_hidden_layers"] - 1 if has_xgenmm_projector else vision_config["num_hidden_layers"]
+    # we are different from llava-1.6, which used the second-to-last layer's hidden states as the image features
+    block_count = vision_config["num_hidden_layers"]
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
     print(KEY_BLOCK_COUNT)
     # xgenmm uses anyres with a grid configuration:
     # 1*2, 2*1, 2*2, 3*1, 1*3, the same as llava-1.6; we just hard-code it here
     image_grid_pinpoints = [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]
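Note: the hard-coded image_grid_pinpoints above is a flattened sequence of (width, height) resolutions used for anyres tiling. A minimal sketch of how it groups into pairs (the variable names here are illustrative, not part of the converter):

    # Sketch only: regroup the flat pinpoint list into (width, height) pairs.
    image_grid_pinpoints = [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]
    grids = [(image_grid_pinpoints[i], image_grid_pinpoints[i + 1])
             for i in range(0, len(image_grid_pinpoints), 2)]
    # -> [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
    # i.e. the 1*2, 2*1, 2*2, 3*1, 1*3 tilings of the 336 px base resolution.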
@@ -206,6 +210,20 @@ if __name__ == "__main__":
     # TODO: need to check; vision_config["hidden_act"] is gelu_pytorch_tanh
     use_gelu = "gelu" in vision_config["hidden_act"].lower()
     fout.add_bool("clip.use_gelu", use_gelu)
+
+    if has_xgenmm_projector:
+        projector = torch.load(args.xgenmm_projector)
+        fout.add_uint32("clip.projector.input_dim", projector["input_dim"])
+        fout.add_uint32("clip.projector.output_dim", projector["output_dim"])
+        fout.add_uint32("clip.projector.num_heads", projector["num_heads"])
+        fout.add_uint32("clip.projector.num_layers", projector["num_layers"])
+        fout.add_uint32("clip.projector.hidden_dim", projector["hidden_dim"])
+        fout.add_float32("clip.projector.dropout", projector["dropout"])
+        fout.add_string("clip.projector.activation", projector["activation"])
+        fout.add_string("clip.projector.norm", projector["norm"])
+        fout.add_string("clip.projector.pooling", projector["pooling"])
+        fout.add_string("clip.projector.pooling_norm", projector["pooling_norm"])
+        fout.add_string("clip.projector.pooling_activation", projector["pooling_activation"])
 
     fout.write_header_to_file()
     fout.write_kv_data_to_file()
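A note on the torch.load(args.xgenmm_projector) call above: the code indexes the loaded object like a plain dict of hyperparameters, so the file it expects would have to look roughly like the sketch below. Every key name and value here is an assumption for illustration only; the actual xgenmm projector checkpoint may be laid out differently.

    import torch

    # Hypothetical layout of the file passed as args.xgenmm_projector (assumed, not from the xgenmm release).
    projector_meta = {
        "input_dim": 1152,           # vision tower hidden size (assumed)
        "output_dim": 3072,          # language model embedding size (assumed)
        "num_heads": 16,
        "num_layers": 6,
        "hidden_dim": 4096,
        "dropout": 0.0,
        "activation": "gelu",
        "norm": "layernorm",
        "pooling": "attention",
        "pooling_norm": "layernorm",
        "pooling_activation": "gelu",
    }
    torch.save(projector_meta, "xgenmm_projector_meta.pt")
    # torch.load("xgenmm_projector_meta.pt") would then return the dict the fout.add_* calls read from.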