initial integration of VIT+Projector
This commit is contained in:
parent a81ba75193
commit f07cb4a73d
4 changed files with 734 additions and 190 deletions
@ -627,10 +627,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        }
    }

    if (ctx->has_xgenmm_projector) {
        // TODO: implement something, for example, image masks
        printf(" use has_xgenmm_projector\n");
    }
    const int patch_size = hparams.patch_size;
    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
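A quick sanity check on the patch arithmetic above (a worked example; it assumes the SigLIP-style defaults used elsewhere in this diff, image_size = 384 and patch_size = 14, which are not spelled out here):

// num_patches   = (384 / 14) * (384 / 14) = 27 * 27 = 729
// num_positions = 729 + 0 = 729 when has_class_embedding is false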
@ -677,15 +673,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
        }
    }
    // printf(" after ctx->has_llava_projector\n");
    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);
    // printf("hi2!");

    embeddings =
        ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
    // printf("hi3!");
    if (ctx->has_minicpmv_projector) {
        int pos_w = image_size_width/patch_size;
        int pos_h = image_size_height/patch_size;
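For readers unfamiliar with the ggml call above: ggml_get_rows performs an embedding lookup, so the add attaches one learned position vector to each patch embedding. A minimal sketch of the equivalent indexing (hypothetical plain-array names, not part of the commit):

// for (int i = 0; i < num_positions; ++i)
//     embeddings[i] += position_embeddings[positions[i]];  // both rows of length hidden_size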
@ -1044,123 +1037,174 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
            GGML_ASSERT(false);
        }
    }
    // xgenmm-projector
    else if (ctx->has_xgenmm_projector)
    {
        if (ctx->proj_type == PROJECTOR_TYPE_PERCEIVER_RESAMPLER)
        {
            struct ggml_tensor * self_latents = model.mm_model_latents;
            struct ggml_tensor * img_embeddings = embeddings;
            // FIXME: hard coded for now
            int n_layer = 6;
            const float scale = model.mm_model_layers[0].scale;
            const int num_head = model.mm_model_layers[0].heads;
            const int dim_head = model.mm_model_layers[0].dim_head;
            const int q_len = self_latents->ne[1];
            const int kv_len = img_embeddings->ne[1] + self_latents->ne[1]; // concat img_embeddings and latents
            const int hidden_size = dim_head * num_head;
            // TODO: repeat for (batch_size, n_query_tokens, dim)
            ggml_tensor * latents = self_latents;

            for (int il = 0; il < n_layer; ++il)
            {
                struct ggml_tensor * residual = latents;
                auto & layer = model.mm_model_layers[il];
                // layer norm
    // build the graph
    ggml_build_forward_expand(gf, embeddings);

                struct ggml_tensor * img_embeddings_normalized = ggml_norm(ctx0, img_embeddings, eps);
                img_embeddings_normalized =
                    ggml_add(ctx0, ggml_mul(ctx0, img_embeddings_normalized, layer.mm_model_ln_media_w),
                             layer.mm_model_ln_media_b);
    ggml_free(ctx0);

                latents = ggml_norm(ctx0, latents, eps);
                latents =
                    ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ln_latents_w), layer.mm_model_ln_latents_b);
    return gf;
}

                // cross attention
                {
                    struct ggml_tensor * Q = ggml_mul_mat(ctx0, layer.mm_model_q_w, latents);
                    Q = ggml_scale_inplace(ctx0, Q, scale);
                    struct ggml_tensor * kv_inputs = ggml_concat(ctx0, img_embeddings_normalized, latents, 1);
                    // if (vision_attn_masks){
                    //     // printf("vision_attn_masks dim0: %ld, dim1: %ld\n", vision_attn_masks->ne[0],
                    //     // vision_attn_masks->ne[1]); create all one tensor
                    //     const int dim0 = latents->ne[1]; // seq length
                    //     const int dim1 = batch_size;
                    //     struct ggml_tensor *all_one_tensor = ggml_new_tensor_2d(ctx0, latents->type, dim0, dim1);
                    //     ggml_set_name(all_one_tensor, "all_one_tensor");
                    //     ggml_set_input(all_one_tensor);

                    //     vision_attn_masks = ggml_concat(ctx0, vision_attn_masks, all_one_tensor, 0);
                    // }
                    struct ggml_tensor * K = ggml_mul_mat(ctx0, layer.mm_model_k_w, kv_inputs);
                    struct ggml_tensor * V = ggml_mul_mat(ctx0, layer.mm_model_v_w, kv_inputs);
                    // permute
                    Q = ggml_reshape_4d(ctx0, Q, dim_head, num_head, q_len, batch_size);
                    Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
                    Q = ggml_reshape_3d(ctx0, Q, dim_head, q_len, num_head * batch_size);

static ggml_cgraph * clip_image_build_graph_vit(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false, ggml_tensor * attn_bias_input = nullptr) {
    if (!ctx->has_vision_encoder) {
        LOG_TEE("This gguf file seems to have no vision encoder\n");
        return nullptr;
    }
    const auto & model = ctx->vision_model;
    const auto & hparams = model.hparams;

                    K = ggml_reshape_4d(ctx0, K, dim_head, num_head, kv_len, batch_size);
                    K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
                    K = ggml_reshape_3d(ctx0, K, dim_head, kv_len, num_head * batch_size);
    const int image_size = hparams.image_size;
    int image_size_width  = image_size;
    int image_size_height = image_size;

                    V = ggml_reshape_4d(ctx0, V, dim_head, num_head, kv_len, batch_size);
                    V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
                    V = ggml_reshape_3d(ctx0, V, kv_len, dim_head, num_head * batch_size);
    const int patch_size = hparams.patch_size;
    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
    const int hidden_size = hparams.hidden_size;
    const int n_head = hparams.n_head;
    const int d_head = hidden_size / n_head;
    int n_layer = hparams.n_layer;
    const float eps = hparams.eps;
    const int batch_size = imgs->size;

                    struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
    GGML_ASSERT(batch_size == 1);

                    // Apply vision attention mask here.
                    // if (vision_attn_masks){
                    // }
                    if (attn_bias_input)
                    {
                        KQ = ggml_add(ctx0, KQ, attn_bias_input);
                    };
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
    ggml_set_name(inp_raw, "inp_raw");
    ggml_set_input(inp_raw);
    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
    inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));

                    // ggml_soft_max_inplace uses a numerically stable softmax implementation:
                    // ggml_soft_max_inplace(ctx0, KQ) = (sim - sim.amax(dim=-1,
                    // keepdim=True).detach()).softmax(dim=-1)
                    KQ = ggml_soft_max_inplace(ctx0, KQ);
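                    // i.e. softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x));
                    // subtracting the row max keeps exp() from overflowing without changing the result.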
    if (ctx->has_patch_bias) {
        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
        inp = ggml_add(ctx0, inp, model.patch_bias);
    }
    struct ggml_tensor * embeddings = inp;
    struct ggml_tensor * pos_embed = nullptr;

                    struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
                    KQV = ggml_reshape_4d(ctx0, KQV, dim_head, q_len, num_head, batch_size);
                    KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
                    KQV = ggml_cont_3d(ctx0, KQV, hidden_size, q_len, batch_size);
    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

                    latents = ggml_mul_mat(ctx0, layer.mm_model_o_w, KQV);
                }
                // residual connection
    embeddings =
        ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));

                latents = ggml_add(ctx0, latents, residual);
                residual = latents; // update residual

                // FFN
                {
                    // layer norm
                    latents = ggml_norm(ctx0, latents, eps);
                    latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ffn_ln_w), layer.mm_model_ffn_ln_b);
                    // feed forward
                    latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_up_w, latents);
                    latents = ggml_gelu_inplace(ctx0, latents);
                    latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_down_w, latents);
                }

                // residual connection
                latents = ggml_add(ctx0, latents, residual);
            }

            // post layer norm
            latents = ggml_norm(ctx0, latents, eps);
            latents = ggml_add(ctx0, ggml_mul(ctx0, latents, model.mm_model_norm_w), model.mm_model_norm_b);
            latents =
                ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_projection_w, latents), model.mm_model_projection_b);
            embeddings = latents;
    if (ctx->has_minicpmv_projector) {
        int pos_w = image_size_width/patch_size;
        int pos_h = image_size_height/patch_size;
        if (ctx->minicpmv_version == 2) {
            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
        }
        else
        {
            GGML_ASSERT(false);
        else if (ctx->minicpmv_version == 3) {
            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
        }
        ggml_set_name(pos_embed, "pos_embed");
        ggml_set_input(pos_embed);
    }
    // pre-layernorm
    if (ctx->has_pre_norm) {
        embeddings = ggml_norm(ctx0, embeddings, eps);
        ggml_set_name(embeddings, "pre_ln");

        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
    }
    // loop over layers
    if (ctx->has_minicpmv_projector) {
        n_layer += 1;
    }
    for (int il = 0; il < n_layer - 1; il++) {
        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states

        //const size_t nb_q_w = model.layers[il].q_w->nb[0];

        // layernorm1
        {
            cur = ggml_norm(ctx0, cur, eps);

            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
                           model.layers[il].ln_1_b);
        }

        // self-attention
        {
            struct ggml_tensor * Q =
                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);

            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);

            struct ggml_tensor * K =
                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);

            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);

            struct ggml_tensor * V =
                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);

            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);

            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
            KQ = ggml_soft_max_inplace(ctx0, KQ);
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
        }
        // attention output
        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);

        // re-add the layer input, e.g., residual
        cur = ggml_add(ctx0, cur, embeddings);

        embeddings = cur; // embeddings = residual, cur = hidden_states
        // layernorm2
        {
            cur = ggml_norm(ctx0, cur, eps);

            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
        }

        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);

        if (ctx->use_gelu) {
            cur = ggml_gelu_inplace(ctx0, cur);
        } else {
            cur = ggml_gelu_quick_inplace(ctx0, cur);
        }

        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);

        // residual 2
        cur = ggml_add(ctx0, embeddings, cur);

        embeddings = cur;
    }
    // post-layernorm
    if (ctx->has_post_norm) {
        embeddings = ggml_norm(ctx0, embeddings, eps);
        ggml_set_name(embeddings, "post_ln");

        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
    }

    // build the graph
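For reference, the Q/K/V reshape-permute sequences in both attention blocks above implement standard multi-head scaled dot-product attention. A minimal CPU-side sketch of what one head computes (an illustrative reference, not part of the commit; assumes row-major contiguous float buffers):

#include <algorithm>
#include <cmath>
#include <vector>

// out = softmax(Q K^T * scale) V for a single head: Q is [q_len, d_head],
// K and V are [kv_len, d_head], out is [q_len, d_head].
static void attention_ref(const float * Q, const float * K, const float * V,
                          float * out, int q_len, int kv_len, int d_head) {
    const float scale = 1.0f / std::sqrt((float) d_head);
    std::vector<float> row(kv_len);
    for (int i = 0; i < q_len; ++i) {
        float maxv = -INFINITY;
        for (int j = 0; j < kv_len; ++j) {
            float s = 0.0f;
            for (int d = 0; d < d_head; ++d) s += Q[i*d_head + d] * K[j*d_head + d];
            row[j] = s * scale;
            maxv = std::max(maxv, row[j]);
        }
        float sum = 0.0f;   // numerically stable softmax over the row
        for (int j = 0; j < kv_len; ++j) { row[j] = std::exp(row[j] - maxv); sum += row[j]; }
        for (int d = 0; d < d_head; ++d) {
            float acc = 0.0f;
            for (int j = 0; j < kv_len; ++j) acc += (row[j] / sum) * V[j*d_head + d];
            out[i*d_head + d] = acc;
        }
    }
}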
@ -1171,6 +1215,126 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
    return gf;
}

static ggml_cgraph * clip_build_graph_xgenmm_projector(clip_ctx * ctx, int batch_size, ggml_tensor * img_embeddings, ggml_tensor * attn_bias_input = nullptr)
{
    const auto & model = ctx->vision_model;
    const auto & hparams = model.hparams;
    // const float eps = hparams.eps; // double check this value
    const float eps = 1e-5;

    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    LOG_TEE("%s: ctx->buf_compute_meta.size(): %zu \n", __func__, ctx->buf_compute_meta.size());
    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph * gf = ggml_new_graph(ctx0);

    // computation starts here
    struct ggml_tensor * self_latents = model.mm_model_latents;

    // FIXME: hard coded for now
    int n_layer = 6;
    const float scale = model.mm_model_layers[0].scale;
    const int num_head = model.mm_model_layers[0].heads;
    const int dim_head = model.mm_model_layers[0].dim_head;
    const int q_len = self_latents->ne[1];
    const int kv_len = img_embeddings->ne[1] + self_latents->ne[1]; // concat img_embeddings and latents
    const int hidden_size = dim_head * num_head;

    ggml_tensor * latents = self_latents;
    ggml_tensor * latents_repeat_along_batch = ggml_new_tensor_3d(ctx0, latents->type, latents->ne[0], latents->ne[1], batch_size);
    latents = ggml_repeat(ctx0, latents, latents_repeat_along_batch);

    ggml_tensor * ans;
    for (int il = 0; il < n_layer; ++il)
    {
        struct ggml_tensor * residual = latents;
        auto & layer = model.mm_model_layers[il];
        // layer norm

        struct ggml_tensor * img_embeddings_normalized = ggml_norm(ctx0, img_embeddings, eps);
        img_embeddings_normalized = ggml_add(
            ctx0, ggml_mul(ctx0, img_embeddings_normalized, layer.mm_model_ln_media_w), layer.mm_model_ln_media_b);

        latents = ggml_norm(ctx0, latents, eps);
        latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ln_latents_w),
                           layer.mm_model_ln_latents_b);

        // cross attention
        {
            struct ggml_tensor * Q = ggml_mul_mat(ctx0, layer.mm_model_q_w, latents);
            Q = ggml_scale_inplace(ctx0, Q, scale);
            struct ggml_tensor * kv_inputs = ggml_concat(ctx0, img_embeddings_normalized, latents, 1);
            struct ggml_tensor * K = ggml_mul_mat(ctx0, layer.mm_model_k_w, kv_inputs);
            struct ggml_tensor * V = ggml_mul_mat(ctx0, layer.mm_model_v_w, kv_inputs);
            // permute
            Q = ggml_reshape_4d(ctx0, Q, dim_head, num_head, q_len, batch_size);
            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
            Q = ggml_reshape_3d(ctx0, Q, dim_head, q_len, num_head * batch_size);

            K = ggml_reshape_4d(ctx0, K, dim_head, num_head, kv_len, batch_size);
            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
            K = ggml_reshape_3d(ctx0, K, dim_head, kv_len, num_head * batch_size);

            V = ggml_reshape_4d(ctx0, V, dim_head, num_head, kv_len, batch_size);
            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
            V = ggml_reshape_3d(ctx0, V, kv_len, dim_head, num_head * batch_size);

            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // Apply vision attention mask here.
            if (attn_bias_input){
                KQ = ggml_cont(ctx0, ggml_reshape_4d(ctx0, KQ, kv_len, q_len, num_head, batch_size));
                attn_bias_input = ggml_cont(ctx0, ggml_reshape_4d(ctx0, attn_bias_input, kv_len, q_len, 1, batch_size));

                KQ = ggml_add(ctx0, KQ, attn_bias_input);

                KQ = ggml_cont(ctx0, ggml_reshape_3d(ctx0, KQ, kv_len, q_len, num_head * batch_size));
            };
            KQ = ggml_soft_max_inplace(ctx0, KQ);

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
            KQV = ggml_reshape_4d(ctx0, KQV, dim_head, q_len, num_head, batch_size);
            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, q_len, batch_size);

            latents = ggml_mul_mat(ctx0, layer.mm_model_o_w, KQV);
        }
        // residual connection

        latents = ggml_add(ctx0, latents, residual);
        residual = latents; // update residual

        // FFN
        {
            // layer norm
            latents = ggml_norm(ctx0, latents, eps);
            latents = ggml_add(ctx0, ggml_mul(ctx0, latents, layer.mm_model_ffn_ln_w),
                               layer.mm_model_ffn_ln_b);
            // feed forward
            latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_up_w, latents);
            latents = ggml_gelu_inplace(ctx0, latents);
            latents = ggml_mul_mat(ctx0, layer.mm_model_ffn_linear_down_w, latents);
        }

        // residual connection
        latents = ggml_add(ctx0, latents, residual);
    }

    // post layer norm
    latents = ggml_norm(ctx0, latents, eps);
    latents = ggml_add(ctx0, ggml_mul(ctx0, latents, model.mm_model_norm_w), model.mm_model_norm_b);
    latents = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_projection_w, latents), model.mm_model_projection_b);
    ans = latents; // 512 * 30xx
    ggml_build_forward_expand(gf, ans);

    ggml_free(ctx0);
    return gf;
}
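One shape detail worth spelling out for the bias addition above, in ggml's ne order (fastest dimension first):

//   KQ:              [kv_len, q_len, num_head, batch_size]
//   attn_bias_input: [kv_len, q_len, 1,        batch_size]

ggml_add broadcasts the singleton head dimension, so one bias plane per batch element is shared by all attention heads before the softmax.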

// read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
    struct ggml_context * meta = NULL;

@ -1272,7 +1436,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
    {
        int idx = gguf_find_key(ctx, KEY_PROJ_TYPE);
        if (idx != -1) {
            const std::string proj_type = gguf_get_val_str(ctx, idx); // CT: assign projector name
            const std::string proj_type = gguf_get_val_str(ctx, idx);
            new_clip->proj_type = clip_projector_type_from_string(proj_type);
        } else {
            new_clip->proj_type = PROJECTOR_TYPE_MLP;
@ -1329,7 +1493,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
        new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
    }

    idx = gguf_find_key(ctx, KEY_HAS_XGENMM_PROJ); // CT: checked.
    idx = gguf_find_key(ctx, KEY_HAS_XGENMM_PROJ);
    if (idx != -1) {
        new_clip->has_xgenmm_projector = gguf_get_val_bool(ctx, idx);
    }

@ -1696,10 +1860,8 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size

struct clip_image_size * clip_image_size_init() {
    struct clip_image_size * load_image_size = new struct clip_image_size();
    // load_image_size->width = 448; // CT: this part is hard coded, need check
    // load_image_size->height = 448;
    load_image_size->width = 384; // CT: this part is hard coded, need check
    load_image_size->height = 384;
    load_image_size->width = 448;
    load_image_size->height = 448;
    return load_image_size;
}

@ -2369,7 +2531,8 @@ int clip_n_patches(const struct clip_ctx * ctx) {

    if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
        n_patches /= 4;
    } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
    }
    else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
        if (ctx->minicpmv_version == 2) {
            n_patches = 96;
        }

@ -2377,6 +2540,9 @@ int clip_n_patches(const struct clip_ctx * ctx) {
            n_patches = 64;
        }
    }
    else if (ctx->proj_type == PROJECTOR_TYPE_PERCEIVER_RESAMPLER){
        n_patches = 128;
    }

    return n_patches;
}

@ -2479,6 +2645,47 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
}

bool clip_image_encode_vit(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
    if (!ctx->has_vision_encoder) {
        LOG_TEE("This gguf file seems to have no vision encoder\n");
        return false;
    }

    clip_image_f32_batch imgs{};
    imgs.size = 1;
    imgs.data = img;
    return clip_image_batch_encode_vit(ctx, n_threads, &imgs, vec);
}

// bool clip_image_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd) {
//     if (!ctx->has_vision_encoder) {
//         LOG_TEE("This gguf file seems to have no vision encoder\n");
//         return false;
//     }

//     // no batch encode for now
//     return clip_image_batch_encode_tokenizer(ctx, n_threads, image_embd_v_m, image_embd_v_m_mask, image_embd);
// }

bool clip_image_encode_tokenizer(struct clip_ctx * ctx, int batch_size, ggml_tensor * img_embeddings, ggml_tensor * attn_bias_input, float * image_embd) {
    if (!ctx->has_vision_encoder) {
        LOG_TEE("This gguf file seems to have no vision encoder\n");
        return false;
    }
    ggml_cgraph * gf = clip_build_graph_xgenmm_projector(ctx, batch_size, img_embeddings, attn_bias_input);
    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
    ggml_backend_graph_compute(ctx->backend, gf);
    struct ggml_tensor * llm_inputs = gf->nodes[gf->n_nodes - 1];
    ggml_backend_tensor_get(llm_inputs, image_embd, 0, ggml_nbytes(llm_inputs));
    clip_free(ctx);
    // ggml_free(tensor.ctx);
    // if (ctx0){
    //     ggml_free(ctx0);
    // }
    return true;
}
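Taken together with the rest of this diff, the intended encode pipeline looks like this (a summary sketch; the 729/1152/128 figures come from comments and the clip_n_patches change elsewhere in this diff):

// clip_image_encode_vit(...)           -> ViT features per sub-image (729 patches x 1152 dims)
// clip_xgenmm_handle_vit_patches(...)  -> stacks sub-image features + builds the any-res attention bias
// clip_image_encode_tokenizer(...)     -> perceiver resampler -> 128 visual tokens for the LLM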

bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
    if (!ctx->has_vision_encoder) {
        LOG_TEE("This gguf file seems to have no vision encoder\n");

@ -2512,7 +2719,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        if(ctx->load_image_size==nullptr){
            ctx->load_image_size= clip_image_size_init();
        }
        ctx->load_image_size= clip_image_size_init(); // CT: hard code
        const int pos_w = ctx->load_image_size->width/patch_size;
        const int pos_h = ctx->load_image_size->height/patch_size;
        {

@ -2611,19 +2817,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
        free(positions_data);
    }
    // FIXME: this is a hack;
    // {
    //     std::cout << __LINE__ << std::endl;
    //     struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
    //     std::cout << __LINE__ << std::endl;
    //     int* patches_data = (int*)malloc(ggml_nbytes(patches));
    //     std::cout << __LINE__ << std::endl;
    //     for (int i = 0; i < num_patches; i++) {
    //         patches_data[i] = i + 1;
    //     }
    //     ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
    //     free(patches_data);
    // }
    }
    if (ggml_backend_is_cpu(ctx->backend)) {
        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);

@ -2642,6 +2835,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
    return true;
}


bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
    ggml_type type = GGML_TYPE_Q4_1;


@ -2817,3 +3011,108 @@ int clip_is_xgenmm(const struct clip_ctx * ctx) {
    }
    return 0;
}

// separate image encoding logic

bool clip_image_batch_encode_vit(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
    if (!ctx->has_xgenmm_projector) {
        LOG_TEE("Separate image encoding process is only for xgenmm now.\n");
        return false;
    }
    if (!ctx->has_vision_encoder) {
        LOG_TEE("This gguf file seems to have no vision encoder\n");
        return false;
    }

    int batch_size = imgs->size;
    GGML_ASSERT(batch_size == 1); // TODO: support multiple images

    // build the inference graph
    ggml_cgraph * gf = clip_image_build_graph_vit(ctx, imgs, ctx->load_image_size, true);

    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
    // set inputs
    const auto & model = ctx->vision_model;
    const auto & hparams = model.hparams;
    const int image_size = hparams.image_size;
    int image_size_width  = image_size;
    int image_size_height = image_size;
    const int patch_size = hparams.patch_size;
    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
    if(ctx->load_image_size==nullptr){
        ctx->load_image_size= clip_image_size_init();
    }
    struct clip_image_size * load_image_size = new struct clip_image_size();
    load_image_size->width = image_size_width;
    load_image_size->height = image_size_height;
    ctx->load_image_size = load_image_size;
    const int pos_w = ctx->load_image_size->width/patch_size;
    const int pos_h = ctx->load_image_size->height/patch_size;
    {
        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
        float * data = (float *)malloc(ggml_nbytes(inp_raw));

        for (size_t i = 0; i < imgs->size; i++) {
            const int nx = imgs->data[i].nx;
            const int ny = imgs->data[i].ny;
            if (!ctx->has_minicpmv_projector) {
                GGML_ASSERT(nx == image_size && ny == image_size);
            }

            const int n = nx * ny;

            for (int b = 0; b < batch_size; b++) {
                for (int k = 0; k < 3; k++) {
                    for (int y = 0; y < ny; y++) {
                        for (int x = 0; x < nx; x++) {
                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
                        }
                    }
                }
            }
        }
        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
        free(data);
    }

    // copied from the minicpm implementation for positional embedding.
    // inspired by siglip:
    // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
    // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
    struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
    int * positions_data = (int *)malloc(ggml_nbytes(positions));
    int bucket_coords_h[70];
    int bucket_coords_w[70];
    for (int i = 0; i < pos_h; i++){
        bucket_coords_h[i] = std::floor(70.0*i/pos_h);
    }
    for (int i = 0; i < pos_w; i++){
        bucket_coords_w[i] = std::floor(70.0*i/pos_w);
    }
    for (int i = 0, id = 0; i < pos_h; i++){
        for (int j = 0; j < pos_w; j++){
            positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
        }
    }
    ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
    free(positions_data);
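A worked example of the bucketing above (assuming pos_w = pos_h = 27, i.e. 384-pixel images with 14-pixel patches): bucket(i) = floor(70*i/27) maps i = 0..26 onto 0, 2, 5, ..., 67, and the final index bucket_h*70 + bucket_w always stays below 70*70 = 4900, the apparent size of the interpolated position table.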

    if (ggml_backend_is_cpu(ctx->backend)) {
        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
    }

#ifdef GGML_USE_METAL
    if (ggml_backend_is_metal(ctx->backend)) {
        ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
    }
#endif
    ggml_backend_graph_compute(ctx->backend, gf);
    // the last node is the embedding tensor
    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
    return true;
}
@ -85,8 +85,12 @@ CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_ima
CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);

CLIP_API bool clip_image_encode      (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
CLIP_API bool clip_image_encode_vit  (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
// CLIP_API bool clip_image_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd);
CLIP_API bool clip_image_encode_tokenizer(struct clip_ctx * ctx, int batch_size, ggml_tensor * img_embeddings, ggml_tensor * attn_bias_input, float * image_embd);
CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);

CLIP_API bool clip_image_batch_encode_vit(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
CLIP_API bool clip_image_batch_encode_tokenizer(struct clip_ctx * ctx, const int n_threads, float * image_embd_v_m, float * image_embd_v_m_mask, float * image_embd);
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);

CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
@ -181,41 +181,41 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
    return ret.c_str();
}

// static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){
//     auto ctx_clip = clip_init_context(params);
//     auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str());
//     if (!embeds) {
//         std::cerr << "error: failed to load image " << fname << ". Terminating\n\n";
//         return NULL;
//     }
static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){
    auto ctx_clip = clip_init_context(params);
    auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str());
    if (!embeds) {
        std::cerr << "error: failed to load image " << fname << ". Terminating\n\n";
        return NULL;
    }

//     // process the prompt
//     if (params->prompt.empty() && params->interactive == false) {
//         LOG_TEE("prompt should be given or interactive mode should be on");
//         return NULL;
//     }
    // process the prompt
    if (params->prompt.empty() && params->interactive == false) {
        LOG_TEE("prompt should be given or interactive mode should be on");
        return NULL;
    }

//     auto model = llava_init(params);
//     if (model == NULL) {
//         fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__);
//         return NULL;
//     }
//     const int64_t t_llava_init_start_us = ggml_time_us();
//     auto ctx_llava = llava_init_context(params, model);
//     ctx_llava->ctx_clip = ctx_clip;
//     const int64_t t_llava_init_end_us = ggml_time_us();
//     float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0;
//     LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms);
    auto model = llava_init(params);
    if (model == NULL) {
        fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__);
        return NULL;
    }
    const int64_t t_llava_init_start_us = ggml_time_us();
    auto ctx_llava = llava_init_context(params, model);
    ctx_llava->ctx_clip = ctx_clip;
    const int64_t t_llava_init_end_us = ggml_time_us();
    float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0;
    LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms);

//     const int64_t t_process_image_start_us = ggml_time_us();
//     process_image(ctx_llava, embeds, params, n_past);
//     const int64_t t_process_image_end_us = ggml_time_us();
//     float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0;
//     LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms);
    const int64_t t_process_image_start_us = ggml_time_us();
    process_image(ctx_llava, embeds, params, n_past);
    const int64_t t_process_image_end_us = ggml_time_us();
    float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0;
    LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms);

//     llava_image_embed_free(embeds);
//     return ctx_llava;
// }
    llava_image_embed_free(embeds);
    return ctx_llava;
}

static struct llava_context * xgenmm_init(gpt_params * params, const std::string & fname, int &n_past){
    auto ctx_clip = clip_init_context(params);

@ -226,8 +226,9 @@ static struct llava_context * xgenmm_init(gpt_params * params, const std::string
        std::cerr << "error: failed to load image " << fname << ". Terminating\n\n";
        return NULL;
    }
    std::cout<< "Start Processing Prompt" << std::endl;
    std::cout<< "Start Processing Prompt: " << std::endl;
    exit(1);
    // TODO:
    // process the prompt
    if (params->prompt.empty() && params->interactive == false) {
        LOG_TEE("prompt should be given or interactive mode should be on");
@ -234,6 +234,246 @@ static bool clip_llava_handle_patches(clip_ctx *ctx_clip, std::vector<float *> &
    return true;
}

static bool clip_xgenmm_handle_vit_patches(clip_ctx *ctx_clip, const clip_image_u8 *img, std::vector<float *> &image_embd_v,
                                           struct clip_image_grid_shape grid_shape, float * image_embd)
// float * image_embd: final output
{
    int original_width = img->nx;
    int original_height = img->ny;
    int num_images = image_embd_v.size();
    int32_t num_patches_per_side = clip_image_size(ctx_clip) / clip_patch_size(ctx_clip);
    int num_patches_width = grid_shape.first;
    int num_patches_height = grid_shape.second;
    int patch_num = num_patches_per_side * num_patches_per_side; // 729
    int hidden_size = clip_hidden_size(ctx_clip); // 1152
    size_t size_ele = ggml_type_size(GGML_TYPE_F32);

    struct
    {
        struct ggml_context* ctx;
    } model;

    // TODO: the required size is not precisely calculated - it's only tens of MB
    size_t ctx_size = 0;

    {
        ctx_size +=
            num_patches_per_side * num_patches_per_side * hidden_size * sizeof(float) * num_images * 8; // image_features
        ctx_size += 1024 * 1024 * ggml_type_size(GGML_TYPE_F32);
    }
    struct ggml_init_params params
    {
        /*.mem_size   =*/ctx_size,
        /*.mem_buffer =*/NULL,
        /*.no_alloc   =*/false, // NOTE: this should be false when using the legacy API
    };

    model.ctx = ggml_init(params);

    struct ggml_tensor* image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, hidden_size, patch_num, num_images - 1);
    struct ggml_tensor* base_image_feature = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, hidden_size, patch_num, 1);

    int dim0 = num_images - 1;
    int dim1 = num_patches_per_side * num_patches_per_side;
    int dim2 = hidden_size;
    float* image_features_data = (float*)image_features->data;
    float* base_image_feature_data = (float*)base_image_feature->data;

    for (int i = 0; i < dim0; i++)
    {
        if (i == 0)
        {
            // base_image_feature_data
            float* image_embd = image_embd_v[i];
            for (int j = 0; j < dim1; j++)
            {
                for (int k = 0; k < dim2; k++)
                {
                    base_image_feature_data[j * dim2 + k] = image_embd[j * dim2 + k];
                }
            }
        }
        else
        {
            // other sub-images
            float* image_embd = image_embd_v[i+1];
            for (int j = 0; j < dim1; j++)
            {
                for (int k = 0; k < dim2; k++)
                {
                    image_features_data[i * dim1 * dim2 + j * dim2 + k] = image_embd[j * dim2 + k];
                }
            }
        }
    }

    struct ggml_tensor* image_features_patchview = ggml_view_4d(
        model.ctx, image_features, num_patches_per_side * hidden_size, num_patches_per_side,
        num_patches_width, num_patches_height, size_ele * num_patches_per_side * hidden_size,
        size_ele * num_patches_per_side * hidden_size * num_patches_per_side,
        size_ele * num_patches_per_side * hidden_size * num_patches_per_side * num_patches_width, 0);

    struct ggml_tensor* permuted_cont =
        ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3));

    struct ggml_tensor* flatten =
        ggml_view_2d(model.ctx, permuted_cont, hidden_size,
                     num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side,
                     size_ele * hidden_size, 0);

    struct ggml_tensor* tensor_3d =
        ggml_reshape_3d(model.ctx, flatten,
                        hidden_size,
                        num_patches_per_side * num_patches_per_side,
                        num_patches_width * num_patches_height);
    tensor_3d = ggml_cont(model.ctx, tensor_3d);
    tensor_3d = ggml_concat(model.ctx, base_image_feature, tensor_3d, 2);
    struct ggml_cgraph* gf = ggml_new_graph(model.ctx);
    ggml_build_forward_expand(gf, tensor_3d);
    ggml_graph_compute_with_ctx(model.ctx, gf, 1);
    struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1];
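A reading of the view/permute/reshape chain above, in ggml ne order (an interpretation offered for clarity; shapes assume the 729-patch / 1152-dim figures used in this function): image_features starts as [1152, 729, w*h]; the 4d view plus permute rearrange the per-sub-image patch rows into the full (w x h) grid layout; the flatten and final reshape return to [1152, 729, w*h] in grid-scan order, and the concat with base_image_feature along dim 2 yields [1152, 729, w*h + 1], one 729-patch slice per sub-image plus the base image.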

    struct
    {
        struct ggml_context* ctx;
    } mask;

    ctx_size = 0;

    {
        ctx_size +=
            num_patches_per_side * num_patches_width * num_patches_per_side * num_patches_height * sizeof(float) * 4;
        ctx_size += 1024 * 1024 * ggml_type_size(GGML_TYPE_F32);
    }

    params =
    {
        /*.mem_size   =*/ctx_size,
        /*.mem_buffer =*/NULL,
        /*.no_alloc   =*/false, // NOTE: this should be false when using the legacy API
    };

    mask.ctx = ggml_init(params);
    int current_height = num_patches_per_side * num_patches_height;
    int current_width = num_patches_per_side * num_patches_width;

    float original_aspect_ratio = (float)original_width / (float)original_height;
    float current_aspect_ratio = (float)current_width / (float)current_height;
    // printf("original_height: %d, original_width: %d, original_aspect_ratio: %.2f\n", original_height, original_width,
    //        original_aspect_ratio);
    // printf("current_height: %d, current_width: %d, current_aspect_ratio: %.2f\n", current_height, current_width,
    //        current_aspect_ratio);
    float scale_factor = 1.0;
    struct ggml_tensor* attention_mask = ggml_new_tensor_2d(mask.ctx, GGML_TYPE_F32, current_width, current_height);
    float* attention_mask_data = (float*)attention_mask->data;
    if (original_aspect_ratio > current_aspect_ratio){
        scale_factor = (float)current_width / (float)original_width;
        int new_height = int(original_height * scale_factor);
        int padding = (current_height - new_height) / 2;
        // printf("new_height: %d, padding: %d\n", new_height, padding);

        for (int i = 0; i < current_height; i++){
            for (int j = 0; j < current_width; j++){
                if (i < padding || i >= current_height - padding)
                {
                    attention_mask_data[i * current_width + j] = 0.0;
                }
                else
                {
                    attention_mask_data[i * current_width + j] = 1.0;
                }
            }
        }
    }else{
        scale_factor = (float)current_height / (float)original_height;
        int new_width = int(original_width * scale_factor);
        int padding = (current_width - new_width) / 2;
        printf("new_width: %d, padding: %d\n", new_width, padding);
        for (int i = 0; i < current_height; i++){
            for (int j = 0; j < current_width; j++){
                if (j < padding || j >= current_width - padding)
                {
                    attention_mask_data[i * current_width + j] = 0.0;
                }
                else
                {
                    attention_mask_data[i * current_width + j] = 1.0;
                }
            }
        }
    }

    attention_mask = ggml_reshape_2d(mask.ctx, attention_mask, num_patches_per_side * num_patches_per_side, num_patches_width * num_patches_height);
    attention_mask = ggml_cont(mask.ctx, attention_mask);
    struct ggml_tensor* all_one_tensor =
        ggml_new_tensor_2d(mask.ctx, GGML_TYPE_F32, num_patches_per_side * num_patches_per_side, 1);
    std::fill_n((float*)all_one_tensor->data, num_patches_per_side * num_patches_per_side, 1.0);
    attention_mask = ggml_concat(mask.ctx, all_one_tensor, attention_mask, 1);

    gf = ggml_new_graph(mask.ctx);
    ggml_build_forward_expand(gf, attention_mask);
    ggml_graph_compute_with_ctx(mask.ctx, gf, 1);
    attention_mask = gf->nodes[gf->n_nodes - 1];
    // memcpy(image_embd_v_m_mask_out, (float *)attention_mask->data, ggml_nbytes(attention_mask));

    // compute attention masks outside of the graph
    struct ggml_tensor * attn_bias_input;
    struct ggml_context * ctx0;
    if (attention_mask)
    {
        const int ctx_size = 1024 * 1024 * 1024;
        struct ggml_init_params params
        {
            /*.mem_size   =*/ctx_size,
            /*.mem_buffer =*/NULL,
            /*.no_alloc   =*/false, // NOTE: this should be false when using the legacy API
        };
        ctx0 = ggml_init(params);
        // vision_attn_mask
        // 1 -> 0
        // 0 -> -inf
        const int batch_size = attention_mask->ne[1];
        const int vision_seq_length = attention_mask->ne[0];
        for (int i = 0; i < batch_size * vision_seq_length; i++)
        {
            if (((float *)attention_mask->data)[i] == 1.0)
            {
                ((float *)attention_mask->data)[i] = 0.0;
            }
            else
            {
                ((float *)attention_mask->data)[i] = -INFINITY;
            }
        }
        const int latents_seq_length = 128;
        struct ggml_tensor * all_zero_tensor = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, latents_seq_length, batch_size);
        std::fill_n((float *)all_zero_tensor->data, latents_seq_length * batch_size, 0.0);

        attention_mask = ggml_concat(ctx0, attention_mask, all_zero_tensor, 0);
        ggml_tensor * attn_bias = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, latents_seq_length + vision_seq_length,
                                                     batch_size, latents_seq_length);
        attn_bias = ggml_repeat(ctx0, attention_mask, attn_bias);
        attn_bias = ggml_cont(ctx0, ggml_permute(ctx0, attn_bias, 0, 2, 1, 3));

        struct ggml_cgraph * gf_temp = ggml_new_graph(ctx0);
        ggml_build_forward_expand(gf_temp, attn_bias);
        ggml_graph_compute_with_ctx(ctx0, gf_temp, 1);
        attn_bias_input = attn_bias;
    }
    int batch_size = num_patches_width * num_patches_height + 1;
    const bool encoded = clip_image_encode_tokenizer(
        ctx_clip, batch_size, result, attn_bias_input, image_embd);

    ggml_free(model.ctx);
    ggml_free(mask.ctx);
    return true;
}
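The 1 -> 0 / 0 -> -inf rewrite inside the function above turns a visibility mask into an additive attention bias: adding 0 leaves a score unchanged, while adding -inf drives its softmax weight to zero. A standalone sketch of just that conversion (illustrative, not part of the commit):

#include <cmath>

// Convert a {0,1} visibility mask in place into an additive attention bias.
static void mask_to_bias(float * mask, int n) {
    for (int i = 0; i < n; ++i) {
        mask[i] = (mask[i] == 1.0f) ? 0.0f : -INFINITY;
    }
}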

static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size)
{
    int width = image->nx;
@ -343,37 +583,41 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
    }
    else if (clip_is_xgenmm(ctx_clip))
    {
        // spatial_unpad llava-1.6 type embedding
        // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a
        // solution to quickly get batching working
        // Get the image embedding right after the ViT; merge before the v tokenizer
        int n_img_pos_out = 0; // # of output visual tokens
        std::vector<float *> image_embd_v;
        image_embd_v.resize(img_res_v.size);
        for (size_t i = 0; i < img_res_v.size; i++)
        {
            {
                n_img_pos_out += clip_n_patches(ctx_clip);

                // size_t allocated_size = clip_embd_nbytes(ctx_clip);
                const int vit_patch_num = clip_image_size(ctx_clip) / clip_patch_size(ctx_clip) * (clip_image_size(ctx_clip) / clip_patch_size(ctx_clip));
                image_embd_v[i] =
                    (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
                const bool encoded = clip_image_encode(
                    (float *)malloc(vit_patch_num * clip_hidden_size(ctx_clip) * sizeof(float)); // for the ViT alone this is 729 * 1152 * 4 = 3359232
                const bool encoded = clip_image_encode_vit(
                    ctx_clip, n_threads, &img_res_v.data[i],
                    image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
                    image_embd_v[i]);
                if (!encoded)
                {
                    LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1,
                            (int)img_res_v.size);
                    return false;
                }
                for (int j = 0; j < 5; j++)
                {
                    printf(" %.4f ", image_embd_v[i][j]);
                }
                printf("\n");
                // for (int j = 0; j < 5; j++)
                // {
                //     printf(" %.4f ", image_embd_v[i][j]);
                // }
                // printf("\n");
            }

        *n_img_pos = n_img_pos_out;
        const int64_t t_img_enc_batch_us = ggml_time_us();
        LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size,
                (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);

        const int32_t *image_grid = clip_image_grid(ctx_clip);

        std::vector<std::pair<int, int>> grid_pinpoints;

        std::vector<std::pair<int, int>> grid_pinpoints; //(384, 768) (768, 384) (768, 768) (1152, 384) (384, 1152)..
        for (int i = 0; i < 32 && image_grid[i] != 0; i += 2)
        {
            grid_pinpoints.push_back({image_grid[i], image_grid[i + 1]});
@ -385,24 +629,17 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
        img_res_v.data = nullptr;

        const int32_t image_size = clip_image_size(ctx_clip);

        struct clip_image_grid_shape grid_shape =
            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size);

        int n_img_pos_out;
        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
        *n_img_pos = n_img_pos_out;
            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size); // grid_shape.first is the width (e.g., 3), grid_shape.second is the height (e.g., 1)

        // patch merging + projection
        clip_xgenmm_handle_vit_patches(ctx_clip, img, image_embd_v, grid_shape, image_embd);
        for (size_t i = 0; i < image_embd_v.size(); i++)
        {
            free(image_embd_v[i]);
        }
        image_embd_v.clear();

        // debug image/segment/normalization content:
        // clip_image_u8 * tmp = clip_image_u8_init();
        // clip_image_convert_f32_to_u8(*image_feature, *tmp);
        // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
    }
    else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0)
    {
@ -512,7 +749,6 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
    // clip_image_convert_f32_to_u8(*image_feature, *tmp);
    // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
    }
    std::cout << __LINE__ << std::endl;
    LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);

    const int64_t t_img_enc_end_us = ggml_time_us();

@ -548,6 +784,10 @@ bool llava_image_embed_make_with_clip_img(clip_ctx *ctx_clip, int n_threads, con
    {
        num_max_patches = 10;
    }
    else if (clip_is_xgenmm(ctx_clip))
    {
        num_max_patches = 10;
    }
    float * image_embd =
        (float *)malloc(clip_embd_nbytes(ctx_clip) * num_max_patches); // TODO: base on grid size / llava model
    if (!image_embd)