Add Janus Attention Pool with Latent Query support in CLIP model
This commit is contained in:
parent 3667a0a4a3
commit 78507168e9

1 changed file with 89 additions and 0 deletions
@@ -571,6 +571,23 @@ struct clip_vision_model {
    struct ggml_tensor * mm_model_ln_kv_b;
    struct ggml_tensor * mm_model_ln_post_w;
    struct ggml_tensor * mm_model_ln_post_b;

    // Janus Attention Pool with Latent Query
    struct ggml_tensor * attn_pool_latent;
    struct ggml_tensor * attn_pool_q_w;
    struct ggml_tensor * attn_pool_q_b;
    struct ggml_tensor * attn_pool_k_w;
    struct ggml_tensor * attn_pool_k_b;
    struct ggml_tensor * attn_pool_v_w;
    struct ggml_tensor * attn_pool_v_b;
    struct ggml_tensor * attn_pool_proj_w;
    struct ggml_tensor * attn_pool_proj_b;
    struct ggml_tensor * attn_pool_norm_w;
    struct ggml_tensor * attn_pool_norm_b;
    struct ggml_tensor * attn_pool_ffn_up_w;
    struct ggml_tensor * attn_pool_ffn_up_b;
    struct ggml_tensor * attn_pool_ffn_down_w;
    struct ggml_tensor * attn_pool_ffn_down_b;
};

struct clip_ctx {
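Annotation (not part of the diff): these tensors parameterize a latent-query attention pool in the style of timm's AttentionPoolLatent. Writing X for the patch embeddings of one image, z for the learned latent query, and splitting the hidden size D into n_head heads of size d_head, the graph built in the third hunk computes, roughly:

$$Q = W_q z + b_q, \qquad K = W_k X + b_k, \qquad V = W_v X + b_v$$
$$y = W_{proj}\,\Big[\operatorname{concat}_{i=1}^{n\_head} \operatorname{softmax}\!\big(Q_i K_i^{\top} / \sqrt{d\_head}\big)\, V_i\Big] + b_{proj}$$
$$\text{out} = y + W_{down}\,\mathrm{GELU}\!\big(W_{up}\,\mathrm{LN}(y) + b_{up}\big) + b_{down}$$

where LN uses attn_pool_norm_w/_b and the MLP uses the attn_pool_ffn_up/_down weights above.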
@@ -580,6 +597,7 @@ struct clip_ctx {
    bool has_minicpmv_projector = false;
    bool has_glm_projector = false;
    bool has_qwen2vl_merger = false;
    bool has_janus_attn_pool_latent = false;
    int minicpmv_version = 2;

    struct clip_vision_model vision_model;
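Annotation: the diff does not show where has_janus_attn_pool_latent is set. A plausible loader sketch, assuming the flag travels as GGUF metadata; the key name and surrounding variable names are hypothetical, not from this commit:

// Hypothetical: read the projector flag from GGUF metadata at model load time.
// "clip.has_janus_attn_pool_latent" is an assumed key name.
const int idx = gguf_find_key(gguf_ctx, "clip.has_janus_attn_pool_latent");
if (idx != -1) {
    new_clip->has_janus_attn_pool_latent = gguf_get_val_bool(gguf_ctx, idx);
}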
@@ -1153,6 +1171,77 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
    }
    // Janus attention pool with latent query
    // TODO: Check the ctx0
    else if (ctx->has_janus_attn_pool_latent) {
        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
            struct ggml_tensor * latent = model.attn_pool_latent; // should be [D, 1, 1]
            struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
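            // Annotation: ggml_repeat(ctx, a, b) broadcasts a to the shape of b;
            // the fresh F32 tensor above serves only as a shape template, tiling
            // the single learned latent query once per image in the batch.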
            struct ggml_tensor * Q = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
                model.attn_pool_q_b
            );
            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrtf((float)d_head));
            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
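            // Annotation, shape flow (ggml's ne[0] is the fastest-changing dim):
            // Q: [D, 1, B] -> [d_head, n_head, 1, B] -> permute -> [d_head, 1, n_head, B]
            //    -> [d_head, 1, n_head * B]
            // The 1/sqrt(d_head) factor is folded into Q before the QK product.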

            struct ggml_tensor * K = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
                model.attn_pool_k_b
            );
            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);

            struct ggml_tensor * V = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
                model.attn_pool_v_b
            );
            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
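            // Annotation: K ends up as [d_head, num_positions, n_head * B], while V
            // is additionally transposed to [num_positions, d_head, n_head * B] so
            // the second matmul below contracts over the position dimension directly.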

            struct ggml_tensor * attn_scores = ggml_mul_mat(ctx0, K, Q);
            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);

            struct ggml_tensor * attn_output = ggml_mul_mat(ctx0, V, attn_scores);
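            // Annotation: ggml_mul_mat(a, b) contracts over ne[0] of both operands, so
            // attn_scores is [num_positions, 1, n_head * B] (softmaxed over positions)
            // and attn_output is [d_head, 1, n_head * B].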
            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);

            attn_output = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
                model.attn_pool_proj_b
            );

            // MLP block with residual: out = x + fc2(GELU(fc1(LN(x))))
            // References:
            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
            struct ggml_tensor * cur = attn_output;
            cur = ggml_norm(ctx0, cur, eps);
            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b);
            cur = ggml_gelu_inplace(ctx0, cur);
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b);

            // Residual connection
            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]

            // Pooling: take the single latent token as the pooled embedding,
            // collapsing [D, 1, B] down to [D, B]
            embeddings = ggml_view_2d(ctx0,
                attn_output,
                attn_output->ne[0],
                attn_output->ne[2],
                attn_output->nb[2],
                0);
        } else {
            GGML_ABORT("fatal error");
        }
    }

    // build the graph
    ggml_build_forward_expand(gf, embeddings);
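Annotation: for intuition, a minimal self-contained C reference of the pooling math the graph implements, stripped down to a single head with identity Q/K/V projections and no output projection or MLP. Entirely illustrative, not part of the commit:

// Single-head latent attention pooling over N position vectors of dimension D,
// on plain arrays: weights = softmax(q . x[n] / sqrt(D)), out = sum_n w[n] * x[n].
#include <math.h>
#include <stdio.h>

#define D 4 // embedding dim
#define N 3 // number of positions

int main(void) {
    const float q[D]    = {0.1f, 0.2f, 0.3f, 0.4f};        // latent query
    const float x[N][D] = {{1,0,0,0},{0,1,0,0},{0,0,1,0}}; // K == V == x here

    // attention weights: softmax over positions of q . x[n] / sqrt(D)
    float w[N], wsum = 0.0f;
    for (int n = 0; n < N; n++) {
        float s = 0.0f;
        for (int d = 0; d < D; d++) s += q[d] * x[n][d];
        w[n] = expf(s / sqrtf((float)D));
        wsum += w[n];
    }

    // pooled output: attention-weighted sum of the value vectors
    float out[D] = {0};
    for (int n = 0; n < N; n++)
        for (int d = 0; d < D; d++)
            out[d] += (w[n] / wsum) * x[n][d];

    for (int d = 0; d < D; d++) printf("%f\n", out[d]);
    return 0;
}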