push partial code here

Yutong Dai 2024-07-25 21:15:07 +00:00
parent bd2d24aa0d
commit fde72a9f86
2 changed files with 513 additions and 3 deletions

@@ -143,7 +143,7 @@ if __name__ == "__main__":
     elif args.xgenmm_projector is not None:
         fname_middle = "mmproj-"
         has_text_encoder = False
-        has_xgenmm_projector = False
+        has_xgenmm_projector = True
     elif args.vision_only:
         fname_middle = "vision-"
         has_text_encoder = False
@@ -189,9 +189,13 @@ if __name__ == "__main__":
     fout.add_uint32("clip.vision.projection_dim", 0)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vision_config["num_attention_heads"])
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), vision_config["layer_norm_eps"])
-    block_count = vision_config["num_hidden_layers"] - 1 if has_xgenmm_projector else vision_config["num_hidden_layers"]
+    # TODO: check this as it might cause bugs
+    # original llava implementation:
+    # block_count = vision_config["num_hidden_layers"] - 1 if has_xgenmm_projector else vision_config["num_hidden_layers"]
+    # we differ from llava 1.6, which used the second-to-last layer's hidden states as the image features
+    block_count = vision_config["num_hidden_layers"]
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
     print(KEY_BLOCK_COUNT)
     # xgenmm uses anyres with a grid configuration:
     # 1*2, 2*1, 2*2, 3*1, 1*3 -- the same as llava 1.6; we just hard-code it here
     image_grid_pinpoints = [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]
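Side note (not part of the commit): the flattened `image_grid_pinpoints` list above encodes (width, height) pairs. A minimal sketch of the decoding, assuming the llava 1.6 convention of consecutive value pairs:

```python
# Decode the flattened anyres pinpoint list into (width, height) pairs.
# With a 336-pixel base tile this reproduces the 1*2, 2*1, 2*2, 3*1, 1*3
# grids named in the comment above.
image_grid_pinpoints = [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]
pairs = list(zip(image_grid_pinpoints[0::2], image_grid_pinpoints[1::2]))
# -> [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
```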
@@ -206,6 +210,20 @@ if __name__ == "__main__":
     # TODO: need to check; vision_config["hidden_act"] is gelu_pytorch_tanh
     use_gelu = "gelu" in vision_config["hidden_act"].lower()
     fout.add_bool("clip.use_gelu", use_gelu)
+    if has_xgenmm_projector:
+        projector = torch.load(args.xgenmm_projector)
+        fout.add_uint32("clip.projector.input_dim", projector["input_dim"])
+        fout.add_uint32("clip.projector.output_dim", projector["output_dim"])
+        fout.add_uint32("clip.projector.num_heads", projector["num_heads"])
+        fout.add_uint32("clip.projector.num_layers", projector["num_layers"])
+        fout.add_uint32("clip.projector.hidden_dim", projector["hidden_dim"])
+        fout.add_float32("clip.projector.dropout", projector["dropout"])
+        fout.add_string("clip.projector.activation", projector["activation"])
+        fout.add_string("clip.projector.norm", projector["norm"])
+        fout.add_string("clip.projector.pooling", projector["pooling"])
+        fout.add_string("clip.projector.pooling_norm", projector["pooling_norm"])
+        fout.add_string("clip.projector.pooling_activation", projector["pooling_activation"])
+
     fout.write_header_to_file()
     fout.write_kv_data_to_file()
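Note (not part of the commit): the new block indexes the object returned by `torch.load(args.xgenmm_projector)` with eleven fixed keys, so the checkpoint has to be a plain dict of scalars and strings. Below is a hypothetical sketch of a compatible checkpoint; every value is a placeholder, and only the key names and their int/float/str types are implied by the diff.

```python
import torch

# Hypothetical projector config checkpoint. Values are made up;
# key names and types mirror the GGUF keys written above.
projector_cfg = {
    "input_dim": 1152,             # -> clip.projector.input_dim    (uint32)
    "output_dim": 3072,            # -> clip.projector.output_dim   (uint32)
    "num_heads": 16,               # -> clip.projector.num_heads    (uint32)
    "num_layers": 6,               # -> clip.projector.num_layers   (uint32)
    "hidden_dim": 4096,            # -> clip.projector.hidden_dim   (uint32)
    "dropout": 0.0,                # -> clip.projector.dropout      (float32)
    "activation": "gelu",          # -> clip.projector.activation   (string)
    "norm": "layernorm",           # -> clip.projector.norm         (string)
    "pooling": "attn",             # -> clip.projector.pooling      (string)
    "pooling_norm": "layernorm",   # -> clip.projector.pooling_norm (string)
    "pooling_activation": "gelu",  # -> clip.projector.pooling_activation (string)
}
torch.save(projector_cfg, "xgenmm_projector.pt")
```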