llama : add Qwen2VL support + multimodal RoPE (#10361)
* Barebone Qwen2VL LLM convertor * Add Qwen2VL cli entrypoint * [WIP] add qwen2vl arch * Verify m-rope output * Add vl-rope/2d-rope support for qwen2vl ViT * update qwen2vl cli tool * update 5D tensor op workaround * [WIP] qwen2vl vision model * make batch and clip utils compatible with qwen2vl * [WIP] create inference workflow, gguf convert script but fix * correcting vision-rope behavior, add the missing last layer back to ViT * add arg parser to qwen2vl_surgery * replace variable size array with vector * cuda-gdb cmake preset * add fp32 mrope, vision rope kernel * add fp16 support for qwen2vl and m-rope * add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION` * fix rope op mode switching, out dated func args * update `llama_hparams` * update to keep up stream changes * resolve linter, test errors * add makefile entry, update speical image padding token * add mrope unit test, fix few compiler warnings * rename `mrope` related function, params * minor updates on debug util, bug fixs * add `m-rope` testcase to `test-backend-ops` * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix traililng whitespce * store `llama_hparams.rope_sections` with fixed size array * update position id tensor size check in GGML_OP_ROPE * minor updates * update `ggml_backend_*_supports_op` of unsupported backends * remote old `rope_section` compare operator --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
56eea0781c
commit
ba1cb19cdd
24 changed files with 1873 additions and 114 deletions
|
@ -102,7 +102,9 @@ static std::string format(const char * fmt, ...) {
|
|||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_USE_SILU "clip.use_silu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
|
@ -129,7 +131,8 @@ static std::string format(const char * fmt, ...) {
|
|||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
|
||||
#define TN_PATCH_BIAS "v.patch_embd.bias"
|
||||
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
|
@ -163,6 +166,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -171,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
};
|
||||
|
||||
|
||||
|
@ -463,7 +468,8 @@ struct clip_vision_model {
|
|||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding;
|
||||
struct ggml_tensor * patch_embeddings;
|
||||
struct ggml_tensor * patch_embeddings_0;
|
||||
struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
|
||||
struct ggml_tensor * patch_bias;
|
||||
struct ggml_tensor * position_embeddings;
|
||||
|
||||
|
@ -553,6 +559,7 @@ struct clip_ctx {
|
|||
bool has_vision_encoder = false;
|
||||
bool has_llava_projector = false;
|
||||
bool has_minicpmv_projector = false;
|
||||
bool has_qwen2vl_merger = false;
|
||||
int minicpmv_version = 2;
|
||||
|
||||
struct clip_vision_model vision_model;
|
||||
|
@ -561,6 +568,7 @@ struct clip_ctx {
|
|||
float image_mean[3];
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
bool has_class_embedding = true;
|
||||
|
@ -606,14 +614,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
// use the image's native resolution when image is avaible
|
||||
if (is_inf) {
|
||||
// if (imgs->data->nx && imgs->data->ny) {
|
||||
image_size_width = imgs->data->nx;
|
||||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int patches_w = image_size_width / patch_size;
|
||||
const int patches_h = image_size_height / patch_size;
|
||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
||||
|
@ -634,10 +654,30 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
|
||||
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||
inp = ggml_reshape_3d(
|
||||
ctx0, inp,
|
||||
hidden_size, patches_w * patches_h, batch_size);
|
||||
}
|
||||
else {
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
}
|
||||
|
||||
if (ctx->has_patch_bias) {
|
||||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||
|
@ -659,12 +699,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||
ggml_set_name(positions, "positions");
|
||||
ggml_set_input(positions);
|
||||
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
}
|
||||
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
int pos_w = image_size_width/patch_size;
|
||||
|
@ -688,7 +730,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
|
||||
// loop over layers
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
|
||||
// TODO: figure out why we doing thing in this way ???
|
||||
n_layer += 1;
|
||||
}
|
||||
for (int il = 0; il < n_layer - 1; il++) {
|
||||
|
@ -710,8 +753,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
Q = ggml_rope_multi(
|
||||
ctx0, Q, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
|
@ -719,6 +767,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
K = ggml_rope_multi(
|
||||
ctx0, K, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
|
@ -758,6 +811,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
|
||||
if (ctx->use_gelu) {
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
} else if (ctx->use_silu) {
|
||||
cur = ggml_silu_inplace(ctx0, cur);
|
||||
} else {
|
||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||
}
|
||||
|
@ -769,6 +824,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
|
@ -1030,6 +1086,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
// GELU activation
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
@ -1206,6 +1275,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
|
||||
if (idx != -1) {
|
||||
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
|
||||
}
|
||||
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
|
||||
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
|
@ -1214,6 +1287,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_SILU);
|
||||
new_clip->use_silu = gguf_get_val_bool(ctx, idx);
|
||||
} catch (std::runtime_error & /*e*/) {
|
||||
new_clip->use_silu = false;
|
||||
}
|
||||
|
||||
if (verbosity >= 1) {
|
||||
LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
|
||||
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||
|
@ -1389,11 +1469,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
} catch(const std::exception& /*e*/) {
|
||||
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
|
||||
}
|
||||
try {
|
||||
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
new_clip->has_qwen2vl_merger = false;
|
||||
}
|
||||
|
||||
// LLaVA projection
|
||||
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
|
@ -1481,6 +1566,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
|
||||
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
@ -1519,6 +1610,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
|
||||
clip_image_f32_batch batch;
|
||||
batch.size = 1;
|
||||
batch.data = nullptr;
|
||||
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
|
||||
ggml_gallocr_reserve(new_clip->compute_alloc, gf);
|
||||
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
|
||||
|
@ -1532,6 +1624,10 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size
|
|||
ctx_clip->load_image_size = load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
|
||||
return ctx_clip->load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_image_size_init() {
|
||||
struct clip_image_size * load_image_size = new struct clip_image_size();
|
||||
load_image_size->width = 448;
|
||||
|
@ -1984,6 +2080,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
|||
}
|
||||
return true;
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
clip_image_u8 * resized = clip_image_u8_init();
|
||||
auto patch_size = clip_patch_size(ctx) * 2;
|
||||
int nx = ceil((float)img->nx / patch_size) * patch_size;
|
||||
int ny = ceil((float)img->ny / patch_size) * patch_size;
|
||||
bicubic_resize(*img, *resized, nx, ny);
|
||||
|
||||
res_imgs->data = new clip_image_f32[1];
|
||||
// clip_image_f32 * res = clip_image_f32_init();
|
||||
normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
|
||||
// res_imgs->data[0] = *res;
|
||||
res_imgs->size = 1;
|
||||
|
||||
// clip_image_f32_free(res);
|
||||
clip_image_u8_free(resized);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pad_to_square = true;
|
||||
if (!ctx->has_vision_encoder) {
|
||||
|
@ -2173,6 +2286,13 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
|||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
|
||||
clip_image_f32 img;
|
||||
img.nx = img_w;
|
||||
img.ny = img_h;
|
||||
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
|
@ -2194,6 +2314,13 @@ const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
|
|||
}
|
||||
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
clip_image_f32 img;
|
||||
img.nx = ctx->vision_model.hparams.image_size;
|
||||
img.ny = ctx->vision_model.hparams.image_size;
|
||||
return clip_n_patches_by_img(ctx, &img);
|
||||
}
|
||||
|
||||
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->vision_model.hparams;
|
||||
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
|
@ -2207,6 +2334,11 @@ int clip_n_patches(const struct clip_ctx * ctx) {
|
|||
else if (ctx->minicpmv_version == 3) {
|
||||
n_patches = 64;
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
int patch_size = params.patch_size * 2;
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||
n_patches = x_patch * y_patch;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
|
@ -2335,7 +2467,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int image_size = hparams.image_size;
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
|
||||
image_size_width = imgs->data[0].nx;
|
||||
image_size_height = imgs->data[0].ny;
|
||||
}
|
||||
|
@ -2355,7 +2487,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
for (size_t i = 0; i < imgs->size; i++) {
|
||||
const int nx = imgs->data[i].nx;
|
||||
const int ny = imgs->data[i].ny;
|
||||
if (!ctx->has_minicpmv_projector) {
|
||||
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
|
||||
GGML_ASSERT(nx == image_size && ny == image_size);
|
||||
}
|
||||
|
||||
|
@ -2413,9 +2545,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||
|
||||
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
|
||||
for(int i=0;i<pos_w * pos_h;++i){
|
||||
for(int j=0;j<embed_dim;++j){
|
||||
pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j];
|
||||
for(int i=0;i < pos_w * pos_h; ++i){
|
||||
for(int j=0; j < embed_dim; ++j){
|
||||
pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2435,7 +2567,34 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
const int pw = image_size_width / patch_size;
|
||||
const int ph = image_size_height / patch_size;
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
||||
int ptr = 0;
|
||||
for (int y = 0; y < ph; y+=2)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=2)
|
||||
{
|
||||
for (int dy = 0; dy < 2; dy++) {
|
||||
for (int dx = 0; dx < 2; dx++) {
|
||||
positions_data[ptr] = y + dy;
|
||||
positions_data[num_patches + ptr] = x + dx;
|
||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
@ -2444,16 +2603,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2626,6 +2785,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
@ -2637,3 +2799,21 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
|
|||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||
return ctx->has_qwen2vl_merger;
|
||||
}
|
||||
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
for (int i = 0; i < h*w*3; i++)
|
||||
{
|
||||
clip_img.buf[i] = img[i];
|
||||
}
|
||||
clip_img.nx = w;
|
||||
clip_img.ny = h;
|
||||
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue