diff --git a/examples/xgenmm/clip.cpp b/examples/xgenmm/clip.cpp index 0680f4f41..c620609e2 100644 --- a/examples/xgenmm/clip.cpp +++ b/examples/xgenmm/clip.cpp @@ -627,10 +627,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } - if (ctx->has_xgenmm_projector) { - //TODO: implement something for example, image masks - printf(" use has_xgenmm_projector\n"); - } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); @@ -677,15 +673,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); } } - // printf(" after ctx->has_llava_projector\n"); struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); ggml_set_name(positions, "positions"); ggml_set_input(positions); - // printf("hi2!"); embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); - // printf("hi3!"); if (ctx->has_minicpmv_projector) { int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; @@ -1044,123 +1037,174 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } - // xgenmm-projector - else if (ctx->has_xgenmm_projector) - { - if (ctx->proj_type == PROJECTOR_TYPE_PERCEIVER_RESAMPLER) - { - struct ggml_tensor * self_latents = model.mm_model_latents; - struct ggml_tensor *img_embeddings = embeddings; - // FIXME: hard coded for now - int n_layer = 6; - const float scale = model.mm_model_layers[0].scale; - const int num_head = model.mm_model_layers[0].heads; - const int dim_head = model.mm_model_layers[0].dim_head; - const int q_len = self_latents->ne[1]; - const int kv_len = img_embeddings->ne[1] + self_latents->ne[1]; // concat img_embeddings and latents - const int hidden_size = dim_head * num_head; - // TODO: repeat for (batch_size, n_query_tokens, dim) - ggml_tensor *latents = self_latents; - for (int il = 0; il < n_layer; ++il) - { - struct ggml_tensor *residual = latents; - auto &layer = model.mm_model_layers[il]; - // layer norm + // build the graph + ggml_build_forward_expand(gf, embeddings); - struct ggml_tensor *img_embeddings_normalized = ggml_norm(ctx0, img_embeddings, eps); - img_embeddings_normalized = - ggml_add(ctx0, ggml_mul(ctx0, img_embeddings_normalized, layer.mm_model_ln_media_w), - layer.mm_model_ln_media_b); + ggml_free(ctx0); - latents = ggml_norm(ctx0, latents, eps); - latents = - ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ln_latents_w), layer.mm_model_ln_latents_b); + return gf; +} - // cross attention - { - struct ggml_tensor *Q = ggml_mul_mat(ctx0, layer.mm_model_q_w, latents); - Q = ggml_scale_inplace(ctx0, Q, scale); - struct ggml_tensor *kv_inputs = ggml_concat(ctx0, img_embeddings_normalized, latents, 1); - // if (vision_attn_masks){ - // // printf("vision_attn_masks dim0: %ld, dim1: %ld\n", vision_attn_masks->ne[0], - // // vision_attn_masks->ne[1]); create all one tensor - // const int dim0 = latents->ne[1]; // seq length - // const int dim1 = batch_size; - // struct ggml_tensor *all_one_tensor = ggml_new_tensor_2d(ctx0, latents->type, dim0, dim1); - // ggml_set_name(all_one_tensor, "all_one_tensor"); - // ggml_set_input(all_one_tensor); - // vision_attn_masks = ggml_concat(ctx0, vision_attn_masks, all_one_tensor, 0); - // } - 
struct ggml_tensor *K = ggml_mul_mat(ctx0, layer.mm_model_k_w, kv_inputs); - struct ggml_tensor *V = ggml_mul_mat(ctx0, layer.mm_model_v_w, kv_inputs); - // permute - Q = ggml_reshape_4d(ctx0, Q, dim_head, num_head, q_len, batch_size); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, dim_head, q_len, num_head * batch_size); +static ggml_cgraph * clip_image_build_graph_vit(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false, ggml_tensor *attn_bias_input = nullptr) { + if (!ctx->has_vision_encoder) { + LOG_TEE("This gguf file seems to have no vision encoder\n"); + return nullptr; + } + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; - K = ggml_reshape_4d(ctx0, K, dim_head, num_head, kv_len, batch_size); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, dim_head, kv_len, num_head * batch_size); + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; - V = ggml_reshape_4d(ctx0, V, dim_head, num_head, kv_len, batch_size); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, kv_len, dim_head, num_head * batch_size); + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + int n_layer = hparams.n_layer; + const float eps = hparams.eps; + const int batch_size = imgs->size; - struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + GGML_ASSERT(batch_size == 1); - // Apply vision attention mask here. 
- // if (vision_attn_masks){ - // } - if (attn_bias_input) - { - KQ = ggml_add(ctx0, KQ, attn_bias_input); - }; + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); - // ggml_soft_max_inplace use numerical stable softmax implementation - // ggml_soft_max_inplace(ctx0, KQ) = (sim - sim.amax(dim=-1, - // keepdim=True).detach()).softmax(dim=-1) - KQ = ggml_soft_max_inplace(ctx0, KQ); + if (ctx->has_patch_bias) { + // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); + inp = ggml_add(ctx0, inp, model.patch_bias); + } + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, dim_head, q_len, num_head, batch_size); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - KQV = ggml_cont_3d(ctx0, KQV, hidden_size, q_len, batch_size); + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); - latents = ggml_mul_mat(ctx0, layer.mm_model_o_w, KQV); - } - // residual connection + embeddings = + ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); - latents = ggml_add(ctx0, latents, residual); - residual = latents; // update residual - - // FFN - { - // layer norm - latents = ggml_norm(ctx0, latents, eps); - latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ffn_ln_w), layer.mm_model_ffn_ln_b); - // feed forward - latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_up_w, latents); - latents = ggml_gelu_inplace(ctx0, latents); - latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_down_w, latents); - } - - // residual connection - latents = ggml_add(ctx0, latents, residual); - } - - // post layer norm - latents = ggml_norm(ctx0, latents, eps); - latents = ggml_add(ctx0, ggml_mul(ctx0, latents, model.mm_model_norm_w), model.mm_model_norm_b); - latents = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_projection_w, latents), model.mm_model_projection_b); - embeddings = latents; + if (ctx->has_minicpmv_projector) { + int pos_w = image_size_width/patch_size; + int pos_h = image_size_height/patch_size; + if (ctx->minicpmv_version == 2) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); } - else - { - GGML_ASSERT(false); + else if (ctx->minicpmv_version == 3) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + ggml_set_name(pos_embed, "pos_embed"); + ggml_set_input(pos_embed); + } + // pre-layernorm + if (ctx->has_pre_norm) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + } + // loop over layers + if (ctx->has_minicpmv_projector) { + n_layer += 1; + } + for (int il = 0; il < n_layer - 1; 
il++) { + struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), + model.layers[il].ln_1_b); + } + + // self-attention + { + + struct ggml_tensor * Q = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); + + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + + if (ctx->use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + // post-layernorm + if (ctx->has_post_norm) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } // build the graph @@ -1171,6 +1215,126 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 return gf; } +static ggml_cgraph *clip_build_graph_xgenmm_projector(clip_ctx *ctx, int batch_size, ggml_tensor *img_embeddings, ggml_tensor *attn_bias_input = nullptr) +{ + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; + // const float eps = hparams.eps; // double check this value + const float eps = 1e-5; + + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + LOG_TEE("%s: ctx->buf_compute_meta.size(): %zu \n", __func__, 
ctx->buf_compute_meta.size()); + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // computation starts here + struct ggml_tensor * self_latents = model.mm_model_latents; + + // FIXME: hard coded for now + int n_layer = 6; + const float scale = model.mm_model_layers[0].scale; + const int num_head = model.mm_model_layers[0].heads; + const int dim_head = model.mm_model_layers[0].dim_head; + const int q_len = self_latents->ne[1]; + const int kv_len = img_embeddings->ne[1] + self_latents->ne[1]; // concat img_embeddings and latents + const int hidden_size = dim_head * num_head; + + ggml_tensor *latents = self_latents; + ggml_tensor *latents_repeat_along_batch = ggml_new_tensor_3d(ctx0, latents->type, latents->ne[0], latents->ne[1], batch_size); + latents = ggml_repeat(ctx0, latents, latents_repeat_along_batch); + + ggml_tensor *ans; + for (int il = 0; il < n_layer; ++il) + { + struct ggml_tensor * residual = latents; + auto & layer = model.mm_model_layers[il]; + // layer norm + + struct ggml_tensor *img_embeddings_normalized = ggml_norm(ctx0, img_embeddings, eps); + img_embeddings_normalized = ggml_add( + ctx0, ggml_mul(ctx0, img_embeddings_normalized, layer.mm_model_ln_media_w), layer.mm_model_ln_media_b); + + latents = ggml_norm(ctx0, latents, eps); + latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ln_latents_w), + layer.mm_model_ln_latents_b); + + //cross attention + { + struct ggml_tensor *Q = ggml_mul_mat(ctx0, layer.mm_model_q_w, latents); + Q = ggml_scale_inplace(ctx0, Q, scale); + struct ggml_tensor *kv_inputs = ggml_concat(ctx0, img_embeddings_normalized, latents, 1); + struct ggml_tensor * K = ggml_mul_mat(ctx0, layer.mm_model_k_w, kv_inputs); + struct ggml_tensor * V = ggml_mul_mat(ctx0, layer.mm_model_v_w, kv_inputs); + // permute + Q = ggml_reshape_4d(ctx0, Q, dim_head, num_head, q_len, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, dim_head, q_len, num_head * batch_size); + + K = ggml_reshape_4d(ctx0, K, dim_head, num_head, kv_len, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, dim_head, kv_len, num_head * batch_size); + + V = ggml_reshape_4d(ctx0, V, dim_head, num_head, kv_len, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, kv_len, dim_head, num_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // Apply vision attention mask here. 
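+            // Descriptive note (added): attn_bias_input, when provided, is expected to be laid out as
+            // (kv_len, q_len, batch); it is reshaped to (kv_len, q_len, 1, batch) below so the same bias
+            // broadcasts across all heads when added to KQ before the softmax. Masked positions carry
+            // -INFINITY, which the caller writes into the mask (see clip_xgenmm_handle_vit_patches).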
+ if (attn_bias_input){ + KQ = ggml_cont(ctx0, ggml_reshape_4d(ctx0, KQ, kv_len, q_len, num_head, batch_size)); + attn_bias_input = ggml_cont(ctx0, ggml_reshape_4d(ctx0, attn_bias_input, kv_len, q_len, 1, batch_size)); + + KQ = ggml_add(ctx0, KQ, attn_bias_input); + + KQ = ggml_cont(ctx0, ggml_reshape_3d(ctx0, KQ, kv_len, q_len, num_head * batch_size)); + }; + KQ = ggml_soft_max_inplace(ctx0, KQ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, dim_head, q_len, num_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, q_len, batch_size); + + latents = ggml_mul_mat(ctx0, layer.mm_model_o_w, KQV); + } + // residual connection + + latents = ggml_add(ctx0, latents, residual); + residual = latents; // update residual + + // FFN + { + // layer norm + latents = ggml_norm(ctx0, latents, eps); + latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ffn_ln_w), + layer.mm_model_ffn_ln_b); + // feed forward + latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_up_w, latents); + latents = ggml_gelu_inplace(ctx0, latents); + latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_down_w, latents); + } + + // residual connection + latents = ggml_add(ctx0, latents, residual); + } + + // post layer norm + latents = ggml_norm(ctx0, latents, eps); + latents = ggml_add(ctx0, ggml_mul(ctx0, latents, model.mm_model_norm_w), model.mm_model_norm_b); + latents = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_projection_w, latents), model.mm_model_projection_b); + ans = latents; // 512 * 30xx + ggml_build_forward_expand(gf, ans); + + ggml_free(ctx0); + return gf; +} + // read and create ggml_context containing the tensors and their data struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { struct ggml_context * meta = NULL; @@ -1272,7 +1436,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { { int idx = gguf_find_key(ctx, KEY_PROJ_TYPE); if (idx != -1) { - const std::string proj_type = gguf_get_val_str(ctx, idx); // CT: assign projector name + const std::string proj_type = gguf_get_val_str(ctx, idx); new_clip->proj_type = clip_projector_type_from_string(proj_type); } else { new_clip->proj_type = PROJECTOR_TYPE_MLP; @@ -1329,7 +1493,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } - idx = gguf_find_key(ctx, KEY_HAS_XGENMM_PROJ); // CT: checked. 
+ idx = gguf_find_key(ctx, KEY_HAS_XGENMM_PROJ); if (idx != -1) { new_clip->has_xgenmm_projector = gguf_get_val_bool(ctx, idx); } @@ -1696,10 +1860,8 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size struct clip_image_size * clip_image_size_init() { struct clip_image_size * load_image_size = new struct clip_image_size(); - // load_image_size->width = 448; // CT: this part is hard coded, need check - // load_image_size->height = 448; - load_image_size->width = 384; // CT: this part is hard coded, need check - load_image_size->height = 384; + load_image_size->width = 448; + load_image_size->height = 448; return load_image_size; } @@ -2369,7 +2531,8 @@ int clip_n_patches(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { n_patches /= 4; - } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + } + else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { if (ctx->minicpmv_version == 2) { n_patches = 96; } @@ -2377,6 +2540,9 @@ int clip_n_patches(const struct clip_ctx * ctx) { n_patches = 64; } } + else if (ctx->proj_type == PROJECTOR_TYPE_PERCEIVER_RESAMPLER){ + n_patches = 128; + } return n_patches; } @@ -2479,6 +2645,47 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 return clip_image_batch_encode(ctx, n_threads, &imgs, vec); } +bool clip_image_encode_vit(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { + if (!ctx->has_vision_encoder) { + LOG_TEE("This gguf file seems to have no vision encoder\n"); + return false; + } + + clip_image_f32_batch imgs{}; + imgs.size = 1; + imgs.data = img; + return clip_image_batch_encode_vit(ctx, n_threads, &imgs, vec); +} + +// bool clip_image_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd) { +// if (!ctx->has_vision_encoder) { +// LOG_TEE("This gguf file seems to have no vision encoder\n"); +// return false; +// } + +// // no batch encode for now +// return clip_image_batch_encode_tokenizer(ctx, n_threads, image_embd_v_m, image_embd_v_m_mask, image_embd); +// } + +bool clip_image_encode_tokenizer(struct clip_ctx * ctx, int batch_size, ggml_tensor *img_embeddings, ggml_tensor *attn_bias_input, float * image_embd) { + if (!ctx->has_vision_encoder) { + LOG_TEE("This gguf file seems to have no vision encoder\n"); + return false; + } + ggml_cgraph *gf = clip_build_graph_xgenmm_projector(ctx, batch_size, img_embeddings, attn_bias_input); + ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + ggml_backend_graph_compute(ctx->backend, gf); + struct ggml_tensor * llm_inputs = gf->nodes[gf->n_nodes - 1]; + ggml_backend_tensor_get(llm_inputs, image_embd, 0, ggml_nbytes(llm_inputs)); + clip_free(ctx); + // ggml_free(tensor.ctx); + // if (ctx0){ + // ggml_free(ctx0); + // } + return true; +} + + bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); @@ -2512,7 +2719,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if(ctx->load_image_size==nullptr){ ctx->load_image_size= clip_image_size_init(); } - ctx->load_image_size= clip_image_size_init(); // CT: hard code const int pos_w = ctx->load_image_size->width/patch_size; const int pos_h = ctx->load_image_size->height/patch_size; { @@ -2611,19 +2817,6 @@ bool clip_image_batch_encode(clip_ctx 
* ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); } - // FIXEME: this is a hack; - // { - // std::cout << __LINE__ << std::endl; - // struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); - // std::cout << __LINE__ << std::endl; - // int* patches_data = (int*)malloc(ggml_nbytes(patches)); - // std::cout << __LINE__ << std::endl; - // for (int i = 0; i < num_patches; i++) { - // patches_data[i] = i + 1; - // } - // ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - // free(patches_data); - // } } if (ggml_backend_is_cpu(ctx->backend)) { ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); @@ -2642,6 +2835,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return true; } + bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { ggml_type type = GGML_TYPE_Q4_1; @@ -2817,3 +3011,108 @@ int clip_is_xgenmm(const struct clip_ctx * ctx) { } return 0; } + + +// separate image encoding logic + +bool clip_image_batch_encode_vit(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { + if (!ctx->has_xgenmm_projector) { + LOG_TEE("Separate image encoding process is only for xgenmm now.\n"); + return false; + } + if (!ctx->has_vision_encoder) { + LOG_TEE("This gguf file seems to have no vision encoder\n"); + return false; + } + + int batch_size = imgs->size; + GGML_ASSERT(batch_size == 1); // TODO: support multiple images + + // build the inference graph + ggml_cgraph * gf = clip_image_build_graph_vit(ctx, imgs, ctx->load_image_size, true); + + ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); + // set inputs + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + if(ctx->load_image_size==nullptr){ + ctx->load_image_size= clip_image_size_init(); + } + struct clip_image_size * load_image_size = new struct clip_image_size(); + load_image_size->width = image_size_width; + load_image_size->height = image_size_height; + ctx->load_image_size = load_image_size; + const int pos_w = ctx->load_image_size->width/patch_size; + const int pos_h = ctx->load_image_size->height/patch_size; + { + struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); + float * data = (float *)malloc(ggml_nbytes(inp_raw)); + + for (size_t i = 0; i < imgs->size; i++) { + const int nx = imgs->data[i].nx; + const int ny = imgs->data[i].ny; + if (!ctx->has_minicpmv_projector) { + GGML_ASSERT(nx == image_size && ny == image_size); + } + + const int n = nx * ny; + + for (int b = 0; b < batch_size; b++) { + for (int k = 0; k < 3; k++) { + for (int y = 0; y < ny; y++) { + for (int x = 0; x < nx; x++) { + data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k]; + } + } + } + } + } + ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); + free(data); + } + + // copy from minicpm implementation for positional embedding. 
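+    // Added note: the loops below spread the pos_h x pos_w patch grid over a fixed 70x70 grid of
+    // position ids (bucket_coords_h / bucket_coords_w), following the interpolated SigLIP position table.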
+ // inspired from siglip: + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + int bucket_coords_h[70]; + int bucket_coords_w[70]; + for (int i = 0; i < pos_h; i++){ + bucket_coords_h[i] = std::floor(70.0*i/pos_h); + } + for (int i = 0; i < pos_w; i++){ + bucket_coords_w[i] = std::floor(70.0*i/pos_w); + } + for (int i = 0, id = 0; i < pos_h; i++){ + for (int j = 0; j < pos_w; j++){ + positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + + + + if (ggml_backend_is_cpu(ctx->backend)) { + ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(ctx->backend)) { + ggml_backend_metal_set_n_cb(ctx->backend, n_threads); + } +#endif + ggml_backend_graph_compute(ctx->backend, gf); + // the last node is the embedding tensor + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1]; + ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + return true; +} \ No newline at end of file diff --git a/examples/xgenmm/clip.h b/examples/xgenmm/clip.h index 1c7812cfa..6daf14f50 100644 --- a/examples/xgenmm/clip.h +++ b/examples/xgenmm/clip.h @@ -85,8 +85,12 @@ CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_ima CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); +CLIP_API bool clip_image_encode_vit (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); +// CLIP_API bool clip_image_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd); +CLIP_API bool clip_image_encode_tokenizer(struct clip_ctx * ctx, int batch_size, ggml_tensor *img_embeddings, ggml_tensor *attn_bias_input, float * image_embd); CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); - +CLIP_API bool clip_image_batch_encode_vit(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); +CLIP_API bool clip_image_batch_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd); CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); diff --git a/examples/xgenmm/xgenmm-cli.cpp b/examples/xgenmm/xgenmm-cli.cpp index fec2545bc..f894c81a0 100644 --- a/examples/xgenmm/xgenmm-cli.cpp +++ b/examples/xgenmm/xgenmm-cli.cpp @@ -181,41 +181,41 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, return ret.c_str(); } -// static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ -// auto ctx_clip = clip_init_context(params); -// auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str()); -// if (!embeds) { -// std::cerr << "error: failed to load image " << fname << ". 
Terminating\n\n"; -// return NULL; -// } +static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ + auto ctx_clip = clip_init_context(params); + auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str()); + if (!embeds) { + std::cerr << "error: failed to load image " << fname << ". Terminating\n\n"; + return NULL; + } -// // process the prompt -// if (params->prompt.empty() && params->interactive == false) { -// LOG_TEE("prompt should be given or interactive mode should be on"); -// return NULL; -// } + // process the prompt + if (params->prompt.empty() && params->interactive == false) { + LOG_TEE("prompt should be given or interactive mode should be on"); + return NULL; + } -// auto model = llava_init(params); -// if (model == NULL) { -// fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); -// return NULL; -// } -// const int64_t t_llava_init_start_us = ggml_time_us(); -// auto ctx_llava = llava_init_context(params, model); -// ctx_llava->ctx_clip = ctx_clip; -// const int64_t t_llava_init_end_us = ggml_time_us(); -// float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; -// LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); + auto model = llava_init(params); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); + return NULL; + } + const int64_t t_llava_init_start_us = ggml_time_us(); + auto ctx_llava = llava_init_context(params, model); + ctx_llava->ctx_clip = ctx_clip; + const int64_t t_llava_init_end_us = ggml_time_us(); + float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; + LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); -// const int64_t t_process_image_start_us = ggml_time_us(); -// process_image(ctx_llava, embeds, params, n_past); -// const int64_t t_process_image_end_us = ggml_time_us(); -// float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; -// LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); + const int64_t t_process_image_start_us = ggml_time_us(); + process_image(ctx_llava, embeds, params, n_past); + const int64_t t_process_image_end_us = ggml_time_us(); + float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; + LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); -// llava_image_embed_free(embeds); -// return ctx_llava; -// } + llava_image_embed_free(embeds); + return ctx_llava; +} static struct llava_context * xgenmm_init(gpt_params * params, const std::string & fname, int &n_past){ auto ctx_clip = clip_init_context(params); @@ -226,8 +226,9 @@ static struct llava_context * xgenmm_init(gpt_params * params, const std::string std::cerr << "error: failed to load image " << fname << ". 
Terminating\n\n";
         return NULL;
     }
-    std::cout<< "Start Processing Prompt" << std::endl;
+    std::cout<< "Start Processing Prompt: " << std::endl;
     exit(1);
+    // TODO:
     // process the prompt
     if (params->prompt.empty() && params->interactive == false) {
         LOG_TEE("prompt should be given or interactive mode should be on");
diff --git a/examples/xgenmm/xgenmm.cpp b/examples/xgenmm/xgenmm.cpp
index 00878d974..51e3d43d0 100644
--- a/examples/xgenmm/xgenmm.cpp
+++ b/examples/xgenmm/xgenmm.cpp
@@ -234,6 +234,246 @@ static bool clip_llava_handle_patches(clip_ctx *ctx_clip, std::vector<float *> &
     return true;
 }
 
+static bool clip_xgenmm_handle_vit_patches(clip_ctx *ctx_clip, const clip_image_u8 *img, std::vector<float *> &image_embd_v,
+                                           struct clip_image_grid_shape grid_shape, float *image_embd)
+                                           // float * image_embd: final output
+{
+    int original_width  = img->nx;
+    int original_height = img->ny;
+    int num_images = image_embd_v.size();
+    int32_t num_patches_per_side = clip_image_size(ctx_clip) / clip_patch_size(ctx_clip);
+    int num_patches_width  = grid_shape.first;
+    int num_patches_height = grid_shape.second;
+    int patch_num   = num_patches_per_side * num_patches_per_side;  // 729
+    int hidden_size = clip_hidden_size(ctx_clip);                   // 1152
+    size_t size_ele = ggml_type_size(GGML_TYPE_F32);
+
+    struct
+    {
+        struct ggml_context* ctx;
+    } model;
+
+    // TODO: the context size is a rough upper bound rather than an exact calculation - it's only tens of MB
+    size_t ctx_size = 0;
+
+    {
+        ctx_size +=
+            num_patches_per_side * num_patches_per_side * hidden_size * sizeof(float) * num_images * 8;  // image_features
+        ctx_size += 1024 * 1024 * ggml_type_size(GGML_TYPE_F32);
+    }
+    struct ggml_init_params params
+    {
+        /*.mem_size   =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,  // NOTE: this should be false when using the legacy API
+    };
+
+    model.ctx = ggml_init(params);
+
+    struct ggml_tensor* image_features     = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, hidden_size, patch_num, num_images - 1);
+    struct ggml_tensor* base_image_feature = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, hidden_size, patch_num, 1);
+
+    int dim0 = num_images - 1;
+    int dim1 = num_patches_per_side * num_patches_per_side;
+    int dim2 = hidden_size;
+    float* image_features_data     = (float*)image_features->data;
+    float* base_image_feature_data = (float*)base_image_feature->data;
+
+    for (int i = 0; i < dim0; i++)
+    {
+        if (i == 0)
+        {
+            // base_image_feature_data
+            float* image_embd = image_embd_v[i];
+            for (int j = 0; j < dim1; j++)
+            {
+                for (int k = 0; k < dim2; k++)
+                {
+                    base_image_feature_data[j * dim2 + k] = image_embd[j * dim2 + k];
+                }
+            }
+        }
+        else
+        {
+            // other sub-images
+            float* image_embd = image_embd_v[i + 1];
+            for (int j = 0; j < dim1; j++)
+            {
+                for (int k = 0; k < dim2; k++)
+                {
+                    image_features_data[i * dim1 * dim2 + j * dim2 + k] = image_embd[j * dim2 + k];
+                }
+            }
+        }
+    }
+
+    struct ggml_tensor* image_features_patchview = ggml_view_4d(
+        model.ctx, image_features, num_patches_per_side * hidden_size, num_patches_per_side,
+        num_patches_width, num_patches_height, size_ele * num_patches_per_side * hidden_size,
+        size_ele * num_patches_per_side * hidden_size * num_patches_per_side,
+        size_ele * num_patches_per_side * hidden_size * num_patches_per_side * num_patches_width, 0);
+
+    struct ggml_tensor* permuted_cont =
+        ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3));
+
+    struct ggml_tensor* flatten =
+        ggml_view_2d(model.ctx, permuted_cont, hidden_size,
+                     num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side,
+                     size_ele * hidden_size, 0);
+
+    struct ggml_tensor* tensor_3d =
+        ggml_reshape_3d(model.ctx, flatten,
+                        hidden_size,
+                        num_patches_per_side * num_patches_per_side,
+                        num_patches_width * num_patches_height);
+    tensor_3d = ggml_cont(model.ctx, tensor_3d);
+    tensor_3d = ggml_concat(model.ctx, base_image_feature, tensor_3d, 2);
+    struct ggml_cgraph* gf = ggml_new_graph(model.ctx);
+    ggml_build_forward_expand(gf, tensor_3d);
+    ggml_graph_compute_with_ctx(model.ctx, gf, 1);
+    struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1];
+
+    struct
+    {
+        struct ggml_context* ctx;
+    } mask;
+
+    ctx_size = 0;
+
+    {
+        ctx_size +=
+            num_patches_per_side * num_patches_width * num_patches_per_side * num_patches_height * sizeof(float) * 4;
+        ctx_size += 1024 * 1024 * ggml_type_size(GGML_TYPE_F32);
+    }
+
+    params =
+    {
+        /*.mem_size   =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,  // NOTE: this should be false when using the legacy API
+    };
+
+    mask.ctx = ggml_init(params);
+    int current_height = num_patches_per_side * num_patches_height;
+    int current_width  = num_patches_per_side * num_patches_width;
+
+    float original_aspect_ratio = (float)original_width / (float)original_height;
+    float current_aspect_ratio  = (float)current_width / (float)current_height;
+    // printf("original_height: %d, original_width: %d, original_aspect_ratio: %.2f\n", original_height, original_width,
+    //        original_aspect_ratio);
+    // printf("current_height: %d, current_width: %d, current_aspect_ratio: %.2f\n", current_height, current_width,
+    //        current_aspect_ratio);
+    float scale_factor = 1.0;
+    struct ggml_tensor* attention_mask = ggml_new_tensor_2d(mask.ctx, GGML_TYPE_F32, current_width, current_height);
+    float* attention_mask_data = (float*)attention_mask->data;
+    if (original_aspect_ratio > current_aspect_ratio){
+        scale_factor = (float)current_width / (float)original_width;
+        int new_height = int(original_height * scale_factor);
+        int padding = (current_height - new_height) / 2;
+        // printf("new_height: %d, padding: %d\n", new_height, padding);
+
+        for (int i = 0; i < current_height; i++){
+            for (int j = 0; j < current_width; j++){
+                if (i < padding || i >= current_height - padding)
+                {
+                    attention_mask_data[i * current_width + j] = 0.0;
+                }
+                else
+                {
+                    attention_mask_data[i * current_width + j] = 1.0;
+                }
+            }
+        }
+    }else{
+        scale_factor = (float)current_height / (float)original_height;
+        int new_width = int(original_width * scale_factor);
+        int padding = (current_width - new_width) / 2;
+        printf("new_width: %d, padding: %d\n", new_width, padding);
+        for (int i = 0; i < current_height; i++){
+            for (int j = 0; j < current_width; j++){
+                if (j < padding || j >= current_width - padding)
+                {
+                    attention_mask_data[i * current_width + j] = 0.0;
+                }
+                else
+                {
+                    attention_mask_data[i * current_width + j] = 1.0;
+                }
+            }
+        }
+    }
+
+    attention_mask = ggml_reshape_2d(mask.ctx, attention_mask, num_patches_per_side * num_patches_per_side, num_patches_width * num_patches_height);
+    attention_mask = ggml_cont(mask.ctx, attention_mask);
+    struct ggml_tensor* all_one_tensor =
+        ggml_new_tensor_2d(mask.ctx, GGML_TYPE_F32, num_patches_per_side * num_patches_per_side, 1);
+    std::fill_n((float*)all_one_tensor->data, num_patches_per_side * num_patches_per_side, 1.0);
+    attention_mask = ggml_concat(mask.ctx, all_one_tensor, attention_mask, 1);
+
+    gf = ggml_new_graph(mask.ctx);
+    ggml_build_forward_expand(gf, attention_mask);
+    ggml_graph_compute_with_ctx(mask.ctx, gf, 1);
+    attention_mask = gf->nodes[gf->n_nodes - 1];
+    // memcpy(image_embd_v_m_mask_out, (float *)attention_mask->data, ggml_nbytes(attention_mask));
+
+    // compute attention masks outside of the graph
+    struct ggml_tensor * attn_bias_input = nullptr;
+    struct ggml_context * ctx0;
+    if (attention_mask)
+    {
+        const int ctx_size = 1024 * 1024 * 1024;
+        struct ggml_init_params params
+        {
+            /*.mem_size   =*/ ctx_size,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ false,  // NOTE: this should be false when using the legacy API
+        };
+        ctx0 = ggml_init(params);
+        // vision_attn_mask
+        //  1 -> 0
+        //  0 -> -inf
+        const int batch_size = attention_mask->ne[1];
+        const int vision_seq_length = attention_mask->ne[0];
+        for (int i = 0; i < batch_size * vision_seq_length; i++)
+        {
+            if (((float *)attention_mask->data)[i] == 1.0)
+            {
+                ((float *)attention_mask->data)[i] = 0.0;
+            }
+            else
+            {
+                ((float *)attention_mask->data)[i] = -INFINITY;
+            }
+        }
+        const int latents_seq_length = 128;  // number of learned latent queries in the perceiver resampler
+        struct ggml_tensor *all_zero_tensor = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, latents_seq_length, batch_size);
+        std::fill_n((float *)all_zero_tensor->data, latents_seq_length * batch_size, 0.0);
+
+        attention_mask = ggml_concat(ctx0, attention_mask, all_zero_tensor, 0);
+        ggml_tensor *attn_bias = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, latents_seq_length + vision_seq_length,
+                                                    batch_size, latents_seq_length);
+        attn_bias = ggml_repeat(ctx0, attention_mask, attn_bias);
+        attn_bias = ggml_cont(ctx0, ggml_permute(ctx0, attn_bias, 0, 2, 1, 3));
+
+        struct ggml_cgraph *gf_temp = ggml_new_graph(ctx0);
+        ggml_build_forward_expand(gf_temp, attn_bias);
+        ggml_graph_compute_with_ctx(ctx0, gf_temp, 1);
+        attn_bias_input = attn_bias;
+    }
+    int batch_size = num_patches_width * num_patches_height + 1;
+    const bool encoded = clip_image_encode_tokenizer(
+        ctx_clip, batch_size, result, attn_bias_input, image_embd);
+
+    ggml_free(model.ctx);
+    ggml_free(mask.ctx);
+    return true;
+}
+
 static clip_image_f32 *only_v2_5_reshape_by_patch(clip_image_f32 *image, int patch_size)
 {
     int width = image->nx;
@@ -343,37 +583,41 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
     }
     else if (clip_is_xgenmm(ctx_clip))
     {
-        // spatial_unpad llava-1.6 type embedding
-        // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a
-        // solution to quickly get batching working
+        // Get the image embedding right after the ViT, merge before the vision tokenizer
+        int n_img_pos_out = 0;  // # of output visual tokens
         std::vector<float *> image_embd_v;
         image_embd_v.resize(img_res_v.size);
         for (size_t i = 0; i < img_res_v.size; i++)
-        {
+        {
+            n_img_pos_out += clip_n_patches(ctx_clip);
+
+            // size_t allocated_size = clip_embd_nbytes(ctx_clip);
+            const int vit_patch_num = clip_image_size(ctx_clip) / clip_patch_size(ctx_clip) * (clip_image_size(ctx_clip) / clip_patch_size(ctx_clip));
             image_embd_v[i] =
-                (float *)malloc(clip_embd_nbytes(ctx_clip));  // 576 patches * 4096 embeddings * 4 bytes = 9437184
-            const bool encoded = clip_image_encode(
+                (float *)malloc(vit_patch_num * clip_hidden_size(ctx_clip) * sizeof(float));  // If vit only, it should be 729 * 1152 * 4 = 3359232
+            const bool encoded = clip_image_encode_vit(
                 ctx_clip, n_threads, &img_res_v.data[i],
-                image_embd_v[i]);  // image data is in 3x336x336 format and will be converted to 336x336x3 inside
+                image_embd_v[i]);
            if (!encoded)
            {
                LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1,
                        (int)img_res_v.size);
                return false;
            }
-            for (int j = 0; j < 5; j++)
-            {
-                printf(" %.4f ", image_embd_v[i][j]);
-            }
-            printf("\n");
+            // for (int j = 0; j < 5; j++)
+            // {
+            //     printf(" %.4f ", image_embd_v[i][j]);
+            // }
+            // printf("\n");
         }
+
+        *n_img_pos = n_img_pos_out;
 
         const int64_t t_img_enc_batch_us = ggml_time_us();
         LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size,
                 (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
-
         const int32_t *image_grid = clip_image_grid(ctx_clip);
-
-        std::vector<std::pair<int, int>> grid_pinpoints;
+
+        std::vector<std::pair<int, int>> grid_pinpoints;  //(384, 768) (768, 384) (768, 768) (1152, 384) (384, 1152)..
         for (int i = 0; i < 32 && image_grid[i] != 0; i += 2)
         {
             grid_pinpoints.push_back({image_grid[i], image_grid[i + 1]});
@@ -385,24 +629,17 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
         img_res_v.data = nullptr;
 
         const int32_t image_size = clip_image_size(ctx_clip);
-
         struct clip_image_grid_shape grid_shape =
-            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size);
-
-        int n_img_pos_out;
-        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
-        *n_img_pos = n_img_pos_out;
+            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size);  // grid_shape.first is width (e.g., 3), grid_shape.second is height (e.g., 1)
+
+        // patch merging + projection
+        clip_xgenmm_handle_vit_patches(ctx_clip, img, image_embd_v, grid_shape, image_embd);
 
         for (size_t i = 0; i < image_embd_v.size(); i++)
         {
             free(image_embd_v[i]);
         }
         image_embd_v.clear();
-
-        // debug image/segment/normalization content:
-        // clip_image_u8 * tmp = clip_image_u8_init();
-        // clip_image_convert_f32_to_u8(*image_feature, *tmp);
-        // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0)
     {
@@ -512,7 +749,6 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
     //     clip_image_convert_f32_to_u8(*image_feature, *tmp);
     //     clip_image_save_to_bmp(*tmp, "image_feature.bmp");
     }
-    std::cout << __LINE__ << std::endl;
     LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);
 
     const int64_t t_img_enc_end_us = ggml_time_us();
@@ -548,6 +784,10 @@ bool llava_image_embed_make_with_clip_img(clip_ctx *ctx_clip, int n_threads, con
     {
         num_max_patches = 10;
     }
+    else if (clip_is_xgenmm(ctx_clip))
+    {
+        num_max_patches = 10;
+    }
     float *image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip) * num_max_patches);  // TODO: base on gridsize/llava model
     if (!image_embd)
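
Usage note (a minimal sketch, not part of the patch): the two entry points added to clip.h are meant to be chained, with clip_image_encode_vit() producing raw ViT features per sub-image and clip_image_encode_tokenizer() running the perceiver resampler over the merged features. The ctx_clip and img_res variables and the thread count below are placeholder assumptions for illustration; in the patch itself this sequence is driven by encode_image_with_clip() and clip_xgenmm_handle_vit_patches().

    // 1) ViT-only encode of one pre-processed sub-image (no projector applied yet).
    //    For the 384px SigLIP encoder with 14px patches this is 27 * 27 = 729 positions
    //    of clip_hidden_size() floats, matching the malloc comment in encode_image_with_clip().
    const int side        = clip_image_size(ctx_clip) / clip_patch_size(ctx_clip);
    const int vit_patches = side * side;
    std::vector<float> vit_embd(vit_patches * clip_hidden_size(ctx_clip));
    if (!clip_image_encode_vit(ctx_clip, /*n_threads=*/4, &img_res, vit_embd.data())) {
        // handle the error
    }
    // 2) The per-sub-image features are then packed into a ggml tensor, an any-res attention
    //    bias is built from the padding mask, and both are handed to the perceiver resampler,
    //    which emits 128 visual tokens per sub-image:
    // clip_image_encode_tokenizer(ctx_clip, batch_size, img_embeddings, attn_bias_input, image_embd);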