From b7fafb7f2aec35c7e2386a785bd1040f1891adad Mon Sep 17 00:00:00 2001 From: ravenouse <85110830+ravenouse@users.noreply.github.com> Date: Tue, 4 Feb 2025 06:40:38 +0000 Subject: [PATCH 1/3] Add script to convert Janus encoder to GGUF format and update requirements --- .../llava/convert_janus_encoder_to_gguf.py | 299 ++++++++++++++++++ examples/llava/requirements.txt | 1 + 2 files changed, 300 insertions(+) create mode 100644 examples/llava/convert_janus_encoder_to_gguf.py diff --git a/examples/llava/convert_janus_encoder_to_gguf.py b/examples/llava/convert_janus_encoder_to_gguf.py new file mode 100644 index 000000000..d8678c1d5 --- /dev/null +++ b/examples/llava/convert_janus_encoder_to_gguf.py @@ -0,0 +1,299 @@ +import argparse +import os +import json +import re + +import torch +import numpy as np +from gguf import * +from janus.models.clip_encoder import CLIPVisionTower + + +TEXT = "clip.text" +VISION = "clip.vision" + + +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) + + +def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: + if name in ( + "logit_scale", + "text_model.embeddings.position_ids", + "vision_model.embeddings.position_ids", + ): + return True + + if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]: + return True + + if name.startswith("v") and not has_vision: + return True + + if name.startswith("t") and not has_text: + return True + + return False + + +def get_tensor_name(name: str) -> str: + if "projection" in name: + return name + if "mm_projector" in name: + name = name.replace("model.mm_projector", "mm") + name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) + name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + return name + + return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
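+    For example, the space byte (0x20) is outside the printable range kept as-is, so it is remapped to U+0120 ('Ġ').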
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) +ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") +ap.add_argument("--clip-model-is-vision", action="store_true", required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") +ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") +ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") +ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 +# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 +# TODO: Double check these two values +default_image_mean = [0.48145466, 0.4578275, 0.40821073] +default_image_std = [0.26862954, 0.26130258, 0.27577711] +ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) +ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) + +# with proper +args = ap.parse_args() + + +if args.use_f32: + print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + +# output in the same directory as the model if output_dir is None +dir_model = args.model_dir + +vocab = None +tokens = None + +# Copied from https://huggingface.co/deepseek-ai/Janus-Pro-7B/blob/main/config.json +# This config is used to initialize the `CLIPVisionTower` class +vision_config = { + "image_size":384, + "model_name": "siglip_large_patch16_384", + "select_feature": "same", + "select_layer": -1 +} +# Copied from https://github.com/deepseek-ai/Janus/blob/main/janus/models/siglip_vit.py +# This config is used to initialize the `vision_tower` in `CLIPVisionTower` class +model_config={ + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, +} + +model = CLIPVisionTower(**vision_config) +model.load_state_dict(torch.load(args.model_dir + "/vision_model.pytorch.bin")) +# Merge the two configs +v_hparams = {**vision_config, **model_config} +t_hparams = None + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if args.use_f32: + ftype = 0 + +fname_middle = None +has_text_encoder = False +has_vision_encoder = True +has_llava_projector = False + +fname_middle = "" + +output_dir = args.output_dir if args.output_dir is not None else dir_model +os.makedirs(output_dir, exist_ok=True) +output_prefix = os.path.basename(output_dir).replace("ggml_", "") +fname_out = 
os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") +fout = GGUFWriter(path=fname_out, arch="clip") + +fout.add_bool("clip.has_text_encoder", has_text_encoder) +fout.add_bool("clip.has_vision_encoder", has_vision_encoder) +fout.add_bool("clip.has_llava_projector", has_llava_projector) +fout.add_file_type(ftype) +model_name = model_config["model_name"] if "model_name" in model_config else os.path.basename(dir_model) +fout.add_name(model_name) +# TODO: Add more information in the description +fout.add_description("vision-only CLIP model") + +if has_vision_encoder: + # vision_model hparams + fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) + fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["width"]) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["width"] * v_hparams["mlp_ratio"]) + fout.add_uint32("clip.vision.projection_dim", model.vision_tower.patch_embed.proj.out_channels) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["heads"]) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), model.vision_tower.blocks[0].norm1.eps) + block_count = v_hparams['layers'] - 1 if has_llava_projector else v_hparams['layers'] + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) + # /** + # "image_grid_pinpoints": [ + # [ + # 336, + # 672 + # ], + # [ + # 672, + # 336 + # ], + # [ + # 672, + # 672 + # ], + # [ + # 1008, + # 336 + # ], + # [ + # 336, + # 1008 + # ] + # ], + # Flattened: + # [ + # 336, 672, + # 672, 336, + # 672, 672, + # 1008, 336, + # 336, 1008 + # ] + # * + # */ + if "image_grid_pinpoints" in v_hparams: + # flatten it + image_grid_pinpoints = [] + for pinpoint in v_hparams["image_grid_pinpoints"]: + for p in pinpoint: + image_grid_pinpoints.append(p) + fout.add_array("clip.vision.image_grid_pinpoints", image_grid_pinpoints) + if "image_crop_resolution" in v_hparams: + fout.add_uint32("clip.vision.image_crop_resolution", v_hparams["image_crop_resolution"]) + if "image_aspect_ratio" in v_hparams: + fout.add_string("clip.vision.image_aspect_ratio", v_hparams["image_aspect_ratio"]) + if "image_split_resolution" in v_hparams: + fout.add_uint32("clip.vision.image_split_resolution", v_hparams["image_split_resolution"]) + if "mm_patch_merge_type" in v_hparams: + fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"]) + if "mm_projector_type" in v_hparams: + fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"]) + + + + image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_std = args.image_std if args.image_std is not None else default_image_std + fout.add_array("clip.vision.image_mean", image_mean) + fout.add_array("clip.vision.image_std", image_std) + +use_gelu = True +fout.add_bool("clip.use_gelu", use_gelu) + + +if has_llava_projector: + model.vision_model.encoder.layers.pop(-1) + projector = torch.load(args.llava_projector) + for name, data in projector.items(): + name = get_tensor_name(name) + # pw and dw conv ndim==4 + if data.ndim == 2 or data.ndim == 4: + data = data.squeeze().numpy().astype(np.float16) + else: + data = data.squeeze().numpy().astype(np.float32) + + fout.add_tensor(name, data) + + print("Projector tensors added\n") + +state_dict = model.state_dict() +for name, data in state_dict.items(): + if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): + # we don't need this + print(f"skipping 
parameter: {name}")
+        continue
+
+    name = get_tensor_name(name)
+    data = data.squeeze().numpy()
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if n_dims == 4:
+        print(f"tensor {name} is always saved in f16")
+        data = data.astype(np.float16)
+        ftype_cur = 1
+    elif ftype == 1:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+    fout.add_tensor(name, data)
+
+
+fout.write_header_to_file()
+fout.write_kv_data_to_file()
+fout.write_tensors_to_file()
+fout.close()
+
+print("Done. Output file: " + fname_out)
diff --git a/examples/llava/requirements.txt b/examples/llava/requirements.txt
index cbcbf26c9..19d46bd7c 100644
--- a/examples/llava/requirements.txt
+++ b/examples/llava/requirements.txt
@@ -3,3 +3,4 @@
 pillow~=10.2.0
 torch~=2.2.1
 torchvision~=0.17.1
+janus @ git+https://github.com/deepseek-ai/Janus.git@main
\ No newline at end of file

From 3667a0a4a3274b874b4a290c63693eb9d34e4311 Mon Sep 17 00:00:00 2001
From: ravenouse <85110830+ravenouse@users.noreply.github.com>
Date: Wed, 5 Feb 2025 20:42:35 +0000
Subject: [PATCH 2/3] Add example clip cli and enhance tensor name processing
 in Janus converter

---
 examples/llava/clip-cli.cpp                   | 118 ++++++++++++++++++
 .../llava/convert_janus_encoder_to_gguf.py    |  99 +++++++++------
 2 files changed, 182 insertions(+), 35 deletions(-)
 create mode 100644 examples/llava/clip-cli.cpp

diff --git a/examples/llava/clip-cli.cpp b/examples/llava/clip-cli.cpp
new file mode 100644
index 000000000..6f40c5116
--- /dev/null
+++ b/examples/llava/clip-cli.cpp
@@ -0,0 +1,118 @@
+//
+// Example usage of just the vision encoder (CLIP) part of the LLAVA codebase.
+// It loads a CLIP model (gguf file) and an image file,
+// computes the image embedding, and prints out (a few elements of) the embedding.
+//
+// Build and run (for example):
+//    ./bin/llama-clip-cli -c model.gguf -i input.png --threads 1 --verbosity 1
+//    ./bin/llama-clip-cli -c clip.gguf -i input.png --threads 1 --verbosity 1
+
+#include "arg.h"
+#include "base64.hpp"
+#include "log.h"
+#include "common.h"
+#include "clip.h"
+#include "llava.h"
+#include "ggml.h"
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <vector>
+#include <iostream>
+
+// Structure to hold our command line parameters.
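+// Only -c/--clip and -i/--image are required; --threads and --verbosity fall back to the defaults below.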
+struct vision_params { + std::string clip_model; // Path to the CLIP model file (gguf) + std::string image_file; // Path to the image file to process + int n_threads = 1; // Number of CPU threads to use + int verbosity = 1; // Verbosity level for model loading +}; + +static void print_usage(const char* progname) { + LOG("\nUsage: %s -c -i [--threads ] [--verbosity ]\n\n", progname); +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + vision_params params; + + // Simple command line parsing + if (argc < 5) { + print_usage(argv[0]); + return 1; + } + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "-c" || arg == "--clip") { + if (i + 1 < argc) { + params.clip_model = argv[++i]; + } else { + print_usage(argv[0]); + return 1; + } + } else if (arg == "-i" || arg == "--image") { + if (i + 1 < argc) { + params.image_file = argv[++i]; + } else { + print_usage(argv[0]); + return 1; + } + } else if (arg == "--threads") { + if (i + 1 < argc) { + params.n_threads = std::atoi(argv[++i]); + } else { + print_usage(argv[0]); + return 1; + } + } else if (arg == "--verbosity") { + if (i + 1 < argc) { + params.verbosity = std::atoi(argv[++i]); + } else { + print_usage(argv[0]); + return 1; + } + } else { + // Unknown argument. + print_usage(argv[0]); + return 1; + } + } + + if (params.clip_model.empty() || params.image_file.empty()) { + print_usage(argv[0]); + return 1; + } + + // Load the CLIP model. + struct clip_ctx * ctx_clip = clip_model_load(params.clip_model.c_str(), params.verbosity); + if (!ctx_clip) { + LOG_ERR("Failed to load clip model from %s\n", params.clip_model.c_str()); + return 1; + } + LOG_INF("Clip model loaded from %s\n", params.clip_model.c_str()); + + // Load and process the image. + llava_image_embed * embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image_file.c_str()); + if (!embed) { + LOG_ERR("Failed to load or process image from %s\n", params.image_file.c_str()); + clip_free(ctx_clip); + return 1; + } + LOG_INF("Image loaded and processed from %s\n", params.image_file.c_str()); + LOG_INF("Image embedding computed with %d positions.\n", embed->n_image_pos); + int print_count = (embed->n_image_pos < 10 ? 
embed->n_image_pos : 10); + LOG_INF("First %d elements: ", print_count); + + for (int i = 0; i < print_count; i++) { + LOG_INF("%f ", embed->embed[i]); + } + LOG_INF("\n"); + + llava_image_embed_free(embed); + clip_free(ctx_clip); + + return 0; +} diff --git a/examples/llava/convert_janus_encoder_to_gguf.py b/examples/llava/convert_janus_encoder_to_gguf.py index d8678c1d5..bae67f283 100644 --- a/examples/llava/convert_janus_encoder_to_gguf.py +++ b/examples/llava/convert_janus_encoder_to_gguf.py @@ -37,17 +37,64 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b return False -def get_tensor_name(name: str) -> str: - if "projection" in name: - return name - if "mm_projector" in name: - name = name.replace("model.mm_projector", "mm") - name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) - name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) - return name +def get_tensor_name_from_janus(name: str) -> str: + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)$', r'v.blk.\1.attn_qkv.\2',name) + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.norm1\.(.*)$', r'v.blk.\1.ln1.\2', name) + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.attn\.proj\.(.*)$', r'v.blk.\1.attn_out.\2', name) + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.norm2\.(.*)$', r'v.blk.\1.ln2.\2', name) + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.mlp\.fc1\.(.*)$', r'v.blk.\1.ffn_down.\2', name) + name = re.sub(r'^vision_tower\.blocks\.(\d+)\.mlp\.fc2\.(.*)$', r'v.blk.\1.ffn_up.\2', name) + name = re.sub(r'^vision_tower\.patch_embed\.proj\.(.*)$', r'v.patch_embd.\1', name) + name = re.sub(r'^vision_tower\.pos_embed$', r'v.position_embd.weight', name) + name = re.sub(r'^vision_tower\.norm\.(weight|bias)$', r'v.post_ln.\1', name) + + name = name.replace("vision_tower", "v") + name = name.replace("text_model", "t") + name = name.replace("vision_model", "v") + name = name.replace("encoder.layers", "blk") + name = name.replace("blocks", "blk") + name = name.replace("embeddings.", "") + name = name.replace("_proj", "") + name = name.replace("self_attn.", "attn_") + name = name.replace("layer_norm", "ln") + name = name.replace("layernorm", "ln") + name = name.replace("mlp.fc1", "ffn_down") + name = name.replace("mlp.fc2", "ffn_up") + name = name.replace("embedding", "embd") + name = name.replace("final", "post") + name = name.replace("layrnorm", "ln") + + return name - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") +def process_and_save_tensor(tensor: torch.Tensor, new_name: str, ftype: int, fout) -> None: + """Process a tensor (squeeze, cast dtype, log) and save it to `fout`.""" + data = tensor.squeeze().numpy() + n_dims = len(data.shape) + ftype_str = {0: "f32", 1: "f16"} + + ftype_cur = 0 + if n_dims == 4: + print(f"tensor {new_name} is always saved in f16") + data = data.astype(np.float16) + ftype_cur = 1 + elif ftype == 1: + if new_name.endswith(".weight") and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = 
data.astype(np.float32) + ftype_cur = 0 + + print(f"{new_name} - {ftype_str[ftype_cur]} - shape = {data.shape}") + fout.add_tensor(new_name, data) def bytes_to_unicode(): """ @@ -261,35 +308,17 @@ for name, data in state_dict.items(): print(f"skipping parameter: {name}") continue - name = get_tensor_name(name) - data = data.squeeze().numpy() + name = get_tensor_name_from_janus(name) - n_dims = len(data.shape) + # Handle the qkv projection weights and biases + if "qkv" in name: + q_tensor, k_tensor, v_tensor = torch.chunk(data, 3, dim=0) - # ftype == 0 -> float32, ftype == 1 -> float16 - ftype_cur = 0 - if n_dims == 4: - print(f"tensor {name} is always saved in f16") - data = data.astype(np.float16) - ftype_cur = 1 - elif ftype == 1: - if name[-7:] == ".weight" and n_dims == 2: - print(" Converting to float16") - data = data.astype(np.float16) - ftype_cur = 1 - else: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 + process_and_save_tensor(q_tensor, name.replace("qkv", "q"), ftype, fout) + process_and_save_tensor(k_tensor, name.replace("qkv", "k"), ftype, fout) + process_and_save_tensor(v_tensor, name.replace("qkv", "v"), ftype, fout) else: - if data.dtype != np.float32: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 - - print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") - fout.add_tensor(name, data) - + process_and_save_tensor(data, name, ftype, fout) fout.write_header_to_file() fout.write_kv_data_to_file() From 78507168e93b347adbb2f320adc4871840484de3 Mon Sep 17 00:00:00 2001 From: ravenouse <85110830+ravenouse@users.noreply.github.com> Date: Fri, 7 Feb 2025 06:04:41 +0000 Subject: [PATCH 3/3] Add Janus Attention Pool with Latent Query support in CLIP model --- examples/llava/clip.cpp | 89 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7367d44cb..eceffd4ea 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -571,6 +571,23 @@ struct clip_vision_model { struct ggml_tensor * mm_model_ln_kv_b; struct ggml_tensor * mm_model_ln_post_w; struct ggml_tensor * mm_model_ln_post_b; + + // Janus Attention Pool with Latent Query + struct ggml_tensor * attn_pool_latent; + struct ggml_tensor * attn_pool_q_w; + struct ggml_tensor * attn_pool_q_b; + struct ggml_tensor * attn_pool_k_w; + struct ggml_tensor * attn_pool_k_b; + struct ggml_tensor * attn_pool_v_w; + struct ggml_tensor * attn_pool_v_b; + struct ggml_tensor * attn_pool_proj_w; + struct ggml_tensor * attn_pool_proj_b; + struct ggml_tensor * attn_pool_norm_w; + struct ggml_tensor * attn_pool_norm_b; + struct ggml_tensor * attn_pool_ffn_up_w; + struct ggml_tensor * attn_pool_ffn_up_b; + struct ggml_tensor * attn_pool_ffn_down_w; + struct ggml_tensor * attn_pool_ffn_down_b; }; struct clip_ctx { @@ -580,6 +597,7 @@ struct clip_ctx { bool has_minicpmv_projector = false; bool has_glm_projector = false; bool has_qwen2vl_merger = false; + bool has_janus_attn_pool_latent = false; int minicpmv_version = 2; struct clip_vision_model vision_model; @@ -1153,6 +1171,77 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + // janus attn pool with latent query + // TODO: Check the ctx0 + else if (ctx->has_janus_attn_pool_latent){ + if (ctx->proj_type == PROJECTOR_TYPE_JANUS) { + struct ggml_tensor* latent = model.attn_pool_latent; // Should be [D, 1, 1] + struct ggml_tensor* 
latent_expanded = ggml_repeat(ctx0, latent,
+                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
+
+            struct ggml_tensor* Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
+                model.attn_pool_q_b
+            );
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
+
+            struct ggml_tensor* K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
+                model.attn_pool_k_b
+            );
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor* V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
+                model.attn_pool_v_b
+            );
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor* attn_scores = ggml_mul_mat(ctx0, K, Q);
+            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);
+
+            struct ggml_tensor* attn_output = ggml_mul_mat(ctx0, V, attn_scores);
+            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
+            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
+            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);
+
+            attn_output = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
+                model.attn_pool_proj_b
+            );
+
+            // MLP block on the pooled token with a pre-norm and a residual connection,
+            // as in timm's AttentionPoolLatent: x = x + fc2(gelu(fc1(layernorm(x))))
+            // References:
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
+            struct ggml_tensor * cur = attn_output;
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b); // fc1
+            cur = ggml_gelu_inplace(ctx0, cur);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b); // fc2
+            // Residual connection
+            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]
+
+            // Pooling: the single latent query yields one token per image, so take a [D, B] view of the result
+            embeddings = ggml_view_2d(ctx0,
+                attn_output,
+                attn_output->ne[0],
+                attn_output->ne[2],
+                attn_output->nb[2],
+                0);
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
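
For reference, the pooling implemented by the ggml graph above corresponds to attention pooling with a single learned latent query, as in timm's AttentionPoolLatent. The sketch below is a minimal PyTorch rendering of the same computation, assuming a dict w of weights exported by the converter above; the key names mirror the GGUF tensor names (attn_pool_q_w, attn_pool_ffn_down_w, ...) and the helper itself is illustrative, not part of the patch.

import torch
import torch.nn.functional as F

def attn_pool_latent(x: torch.Tensor, w: dict, n_head: int) -> torch.Tensor:
    """Attention pooling with one learned latent query. x: [B, N, D] patch embeddings."""
    B, N, D = x.shape
    d_head = D // n_head

    # Project the learned latent query (shape [1, 1, D]) and the patch embeddings.
    q = F.linear(w["attn_pool_latent"].expand(B, 1, D), w["attn_pool_q_w"], w["attn_pool_q_b"])
    k = F.linear(x, w["attn_pool_k_w"], w["attn_pool_k_b"])
    v = F.linear(x, w["attn_pool_v_w"], w["attn_pool_v_b"])

    # Split into heads: [B, n_head, tokens, d_head].
    q = q.view(B, 1, n_head, d_head).transpose(1, 2)
    k = k.view(B, N, n_head, d_head).transpose(1, 2)
    v = v.view(B, N, n_head, d_head).transpose(1, 2)

    # The single latent query attends over all patch tokens.
    attn = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5, dim=-1)
    pooled = (attn @ v).transpose(1, 2).reshape(B, 1, D)
    pooled = F.linear(pooled, w["attn_pool_proj_w"], w["attn_pool_proj_b"])

    # Pre-norm MLP with residual: x = x + fc2(gelu(fc1(layernorm(x)))).
    h = F.layer_norm(pooled, (D,), w["attn_pool_norm_w"], w["attn_pool_norm_b"])
    h = F.linear(h, w["attn_pool_ffn_down_w"], w["attn_pool_ffn_down_b"])  # fc1
    h = F.linear(F.gelu(h), w["attn_pool_ffn_up_w"], w["attn_pool_ffn_up_b"])  # fc2
    pooled = pooled + h

    return pooled.squeeze(1)  # [B, D]

With a single latent query the pooled output is one vector per image, which is why the graph above ends by taking a [D, B] view of the result instead of selecting a CLS token.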