diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 7367d44cb..eceffd4ea 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -571,6 +571,23 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // Janus attention pool with latent query: latent, q/k/v/out projections, LayerNorm, MLP
+    struct ggml_tensor * attn_pool_latent;
+    struct ggml_tensor * attn_pool_q_w;
+    struct ggml_tensor * attn_pool_q_b;
+    struct ggml_tensor * attn_pool_k_w;
+    struct ggml_tensor * attn_pool_k_b;
+    struct ggml_tensor * attn_pool_v_w;
+    struct ggml_tensor * attn_pool_v_b;
+    struct ggml_tensor * attn_pool_proj_w;
+    struct ggml_tensor * attn_pool_proj_b;
+    struct ggml_tensor * attn_pool_norm_w;
+    struct ggml_tensor * attn_pool_norm_b;
+    struct ggml_tensor * attn_pool_ffn_up_w;
+    struct ggml_tensor * attn_pool_ffn_up_b;
+    struct ggml_tensor * attn_pool_ffn_down_w;
+    struct ggml_tensor * attn_pool_ffn_down_b;
 };
 
 struct clip_ctx {
@@ -580,6 +597,7 @@ struct clip_ctx {
     bool has_minicpmv_projector = false;
     bool has_glm_projector = false;
     bool has_qwen2vl_merger = false;
+    bool has_janus_attn_pool_latent = false;
     int minicpmv_version = 2;
 
     struct clip_vision_model vision_model;
@@ -1153,6 +1171,80 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
     }
 
+    // Janus attention pool with a learned latent query (timm AttentionPoolLatent):
+    //   x = proj(attn(Q = latent, K/V = embeddings)); x = x + mlp(norm(x)); out = x[:, 0]
+    else if (ctx->has_janus_attn_pool_latent) {
+        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
+            // broadcast the latent query over the batch
+            struct ggml_tensor * latent = model.attn_pool_latent; // expected [D, 1, 1]
+            struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
+                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
+
+            // Q: [d_head, 1, n_head * B], pre-scaled by 1/sqrt(d_head)
+            struct ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
+                model.attn_pool_q_b);
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
+
+            // K: [d_head, num_positions, n_head * B]
+            struct ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
+                model.attn_pool_k_b);
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            // V (transposed for the second mat-mul): [num_positions, d_head, n_head * B]
+            struct ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
+                model.attn_pool_v_b);
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            // softmax(K^T Q) over the positions: [num_positions, 1, n_head * B]
+            struct ggml_tensor * attn_scores = ggml_mul_mat(ctx0, K, Q);
+            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);
+
+            // weighted sum of V, then merge the heads back: [D, 1, B]
+            struct ggml_tensor * attn_output = ggml_mul_mat(ctx0, V, attn_scores);
+            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
+            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
+            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);
+
+            // output projection
+            attn_output = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
+                model.attn_pool_proj_b);
+
+            // LayerNorm in front of the MLP: normalize, then element-wise scale and shift.
+            // The norm weight is a per-channel vector, not a matrix, so no ggml_mul_mat here.
+            struct ggml_tensor * cur = attn_output;
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
+
+            // MLP: fc1 -> gelu -> fc2 (timm Mlp has no inner norm by default); order kept as ffn_down, ffn_up - TODO confirm
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b);
+            cur = ggml_gelu_inplace(ctx0, cur);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b);
+
+            // Residual connection
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
+            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]
+
+            // Pooling: take token 0 of every batch entry -> [D, B] (byte offset 0 into attn_output)
+            embeddings = ggml_view_2d(ctx0, attn_output,
+                attn_output->ne[0], attn_output->ne[2],
+                attn_output->nb[2], 0);
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
 