Add Janus Attention Pool with Latent Query support in CLIP model

ravenouse 2025-02-07 06:04:41 +00:00
parent 3667a0a4a3
commit 78507168e9

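The new projector path pools the ViT patch embeddings with a latent-query attention block modeled on timm's AttentionPoolLatent (linked from the graph code below): a learned latent vector is broadcast across the batch and used as a single query that attends over all patch positions; the attention output then passes through an output projection and a LayerNorm + MLP residual branch, collapsing the [hidden_size, num_positions, batch] activations into one pooled [hidden_size, batch] embedding per image.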

@@ -571,6 +571,23 @@ struct clip_vision_model {
    struct ggml_tensor * mm_model_ln_kv_b;
    struct ggml_tensor * mm_model_ln_post_w;
    struct ggml_tensor * mm_model_ln_post_b;

    // Janus attention pool with latent query
    struct ggml_tensor * attn_pool_latent;
    struct ggml_tensor * attn_pool_q_w;
    struct ggml_tensor * attn_pool_q_b;
    struct ggml_tensor * attn_pool_k_w;
    struct ggml_tensor * attn_pool_k_b;
    struct ggml_tensor * attn_pool_v_w;
    struct ggml_tensor * attn_pool_v_b;
    struct ggml_tensor * attn_pool_proj_w;
    struct ggml_tensor * attn_pool_proj_b;
    struct ggml_tensor * attn_pool_norm_w;
    struct ggml_tensor * attn_pool_norm_b;
    struct ggml_tensor * attn_pool_ffn_up_w;
    struct ggml_tensor * attn_pool_ffn_up_b;
    struct ggml_tensor * attn_pool_ffn_down_w;
    struct ggml_tensor * attn_pool_ffn_down_b;
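    // Field naming follows timm's AttentionPoolLatent (see the references cited
    // in clip_image_build_graph below); this mapping is inferred, and the
    // separate k/v weights assume the converter splits timm's fused kv
    // projection in half:
    //   attn_pool_latent                   <- latent (the learned query)
    //   attn_pool_q_w/_b                   <- q.weight / q.bias
    //   attn_pool_k_w/_b, attn_pool_v_w/_b <- kv.weight / kv.bias, split
    //   attn_pool_proj_w/_b                <- proj.weight / proj.bias
    //   attn_pool_norm_w/_b                <- norm.weight / norm.bias (pre-MLP LayerNorm)
    //   attn_pool_ffn_up_w/_b              <- mlp.fc1.weight / .bias
    //   attn_pool_ffn_down_w/_b            <- mlp.fc2.weight / .bias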
};

struct clip_ctx {
@@ -580,6 +597,7 @@ struct clip_ctx {
    bool has_minicpmv_projector = false;
    bool has_glm_projector = false;
    bool has_qwen2vl_merger = false;
    bool has_janus_attn_pool_latent = false;
    int minicpmv_version = 2;

    struct clip_vision_model vision_model;
@@ -1153,6 +1171,77 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
    }
    // Janus attention pool with latent query
    // TODO: check whether ctx0 is the right context here
    else if (ctx->has_janus_attn_pool_latent) {
        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
            struct ggml_tensor * latent = model.attn_pool_latent; // learned latent query, [D, 1, 1]
            // broadcast the latent across the batch: [D, 1, 1] -> [D, 1, B]
            struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size));
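            // note: ggml_repeat only reads the second tensor's shape; the fresh
            // F32 tensor allocated above is just a shape template for the
            // broadcast, its contents are never used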
            // Q comes from the latent: [D, 1, B]
            struct ggml_tensor * Q = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
                model.attn_pool_q_b);
            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
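            // Q is now [d_head, 1, n_head * B]: one pre-scaled latent query per head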
            // K and V come from the patch embeddings
            struct ggml_tensor * K = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
                model.attn_pool_k_b);
            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
            struct ggml_tensor * V = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
                model.attn_pool_v_b);
            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
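            // V is laid out as [num_positions, d_head, n_head * B] so that the
            // ggml_mul_mat(V, attn_scores) below contracts over the positions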
            struct ggml_tensor * attn_scores = ggml_mul_mat(ctx0, K, Q);
            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);
            struct ggml_tensor * attn_output = ggml_mul_mat(ctx0, V, attn_scores);
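            // shape check (ggml_mul_mat contracts over ne0; H = n_head, N = num_positions):
            //   K [d_head, N, H*B] x Q [d_head, 1, H*B] -> scores [N, 1, H*B]
            //   soft_max normalizes over ne0 = N
            //   V [N, d_head, H*B] x scores [N, 1, H*B] -> attn_output [d_head, 1, H*B]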
            // merge the heads back together: [d_head, 1, H*B] -> [D, 1, B]
            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);
            // output projection
            attn_output = ggml_add(ctx0,
                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
                model.attn_pool_proj_b);
            // pre-MLP LayerNorm, then MLP: fc1 -> gelu -> fc2
            // References:
            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
            struct ggml_tensor * cur = attn_output;
            cur = ggml_norm(ctx0, cur, eps);
            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b);
            cur = ggml_gelu_inplace(ctx0, cur);
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b);
            // residual connection
            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]

            // pooling: with a single latent token, selecting the first token just
            // drops the singleton dimension -> [D, B], one embedding per image
            embeddings = ggml_view_2d(ctx0, attn_output,
                attn_output->ne[0], attn_output->ne[2],
                attn_output->nb[2], 0);
        } else {
            GGML_ABORT("fatal error");
        }
    }
    // build the graph
    ggml_build_forward_expand(gf, embeddings);