make batch and clip utils compatible with qwen2vl

HimariO 2024-10-18 18:59:47 +08:00
parent c13edfed59
commit 7e9fc7202e
4 changed files with 72 additions and 28 deletions

examples/llava/clip.cpp

@@ -673,6 +673,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
ctx0, inp,
hidden_size, patches_w * patches_h, batch_size);
// ggml_build_forward_expand(gf, inp);
// ggml_free(ctx0);
// return gf;
}
else {
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
@@ -756,7 +760,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
if (ctx->has_qwen2vl_merger) {
Q = ggml_mrope_ext(
ctx0, Q, positions, nullptr,
d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX*/, 32768, 1000000, 1, 0, 1, 32, 1);
d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX*/, 32768, 10000, 1, 0, 1, 32, 1);
}
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
@@ -769,7 +773,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
if (ctx->has_qwen2vl_merger) {
K = ggml_mrope_ext(
ctx0, K, positions, nullptr,
d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX*/, 32768, 1000000, 1, 0, 1, 32, 1);
d_head/2, mrope_sections, 2 /*LLAMA_ROPE_TYPE_NEOX*/, 32768, 10000, 1, 0, 1, 32, 1);
}
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
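For orientation: ggml_mrope_ext rotates the d_head/2 rotary pairs of each head in sections. mrope_sections holds the per-section widths, and each section takes its rotation angle from a different row of the positions tensor (the per-patch (y, x, 0) rows built below in clip_image_batch_encode). The base frequency also drops from 1000000 to 10000 here, the theta conventionally used by Qwen2-VL's vision tower. A minimal sketch of the section lookup, assuming up to four sections as in ggml's m-rope; the helper name is hypothetical:

    // Hypothetical helper: which position stream rotates rotary pair i,
    // given the per-section widths (assumption: up to 4 sections, ggml-style).
    static int mrope_section_of(int i, const int sections[4]) {
        int end = 0;
        for (int s = 0; s < 4; s++) {
            end += sections[s];
            if (i < end) {
                return s; // pair i takes its angle from positions row s
            }
        }
        return 3; // past all sections: clamp to the last stream
    }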
@@ -823,7 +827,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// ggml_build_forward_expand(gf, embeddings);
// ggml_free(ctx0);
// return gf;
// post-layernorm
if (ctx->has_post_norm) {
@@ -1623,6 +1632,10 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size
ctx_clip->load_image_size = load_image_size;
}
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
return ctx_clip->load_image_size;
}
struct clip_image_size * clip_image_size_init() {
struct clip_image_size * load_image_size = new struct clip_image_size();
load_image_size->width = 448;
@@ -2086,6 +2099,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
// clip_image_f32 * res = clip_image_f32_init();
normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
// res_imgs->data[0] = *res;
res_imgs->size = 1;
// clip_image_f32_free(res);
clip_image_u8_free(resized);
@@ -2280,6 +2294,13 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
clip_image_f32 img;
img.nx = img_w;
img.ny = img_h;
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
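A hedged usage sketch for the new per-image sizing; note the signature takes height before width, and ctx/img stand in for the caller's values:

    // Hypothetical call site: size the embedding buffer from the preprocessed
    // image's own dimensions rather than the model's fixed image_size.
    float * embd = (float *) malloc(clip_embd_nbytes_by_img(ctx, img->ny, img->nx)); // (h, w)

The call sites added below pass nx first; since only the product of the two patch counts enters the size, the order is harmless here.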
int32_t clip_image_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_size;
}
@@ -2561,26 +2582,17 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int pw = image_size_width / patch_size;
const int ph = image_size_height / patch_size;
int* positions_data = (int*)malloc(ggml_nbytes(positions));
// for (size_t y = 0; y < ph; y++)
// {
// for (size_t x = 0; x < pw; x++)
// {
// positions_data[y * pw + x] = y;
// positions_data[num_patches + (y * pw + x)] = x;
// positions_data[num_patches * 2 + (y * pw + x)] = 0;
// }
// }
int ptr = -1;
for (size_t y = 0; y < ph; y+=2)
{
for (size_t x = 0; x < pw; x+=2)
{
for (size_t dy = 0; y < 2; y++) {
for (size_t dx = 0; x < 2; x++) {
positions_data[ptr++] = y + dy;
positions_data[ptr++] = x + dx;
positions_data[ptr++] = 0;
for (size_t dy = 0; dy < 2; dy++) {
for (size_t dx = 0; dx < 2; dx++) {
positions_data[++ptr] = y + dy;
positions_data[num_patches + ptr] = x + dx;
positions_data[num_patches * 2 + ptr] = 0;
}
}
}
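To make the new ordering concrete, here is a standalone sketch of the same 2x2-window traversal on an illustrative 4x4 patch grid; it shows that the four patches the merger pools into one token land adjacently in the position buffer:

    #include <cstdio>

    int main() {
        const int ph = 4, pw = 4; // illustrative patch grid
        for (int y = 0; y < ph; y += 2)
            for (int x = 0; x < pw; x += 2)
                for (int dy = 0; dy < 2; dy++)
                    for (int dx = 0; dx < 2; dx++)
                        std::printf("(%d,%d) ", y + dy, x + dx);
        std::printf("\n"); // (0,0) (0,1) (1,0) (1,1) (0,2) (0,3) (1,2) (1,3) (2,0) ...
        return 0;
    }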
@@ -2780,6 +2792,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return 3584;
}
}
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -2792,6 +2807,10 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
return 0;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}
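Caller-side, the two probes distinguish the "encode each sub-image" model families; a rough sketch of the dispatch that llava.cpp builds on below (branch bodies elided):

    if (clip_is_qwen2vl(ctx)) {
        // native-resolution path: one encode per image, no grid slicing
    } else if (int v = clip_is_minicpmv(ctx)) {
        // minicpmv path: v is 2 or 3 and selects the per-version reshape
    }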
bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;

examples/llava/clip.h

@@ -45,6 +45,7 @@ CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity
CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
@@ -61,6 +62,7 @@ CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
@@ -87,6 +89,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
#ifdef __cplusplus

examples/llava/llava.cpp

@@ -259,25 +259,33 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
if (clip_is_minicpmv(ctx_clip)) {
if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
std::vector<float *> image_embd_v;
image_embd_v.resize(img_res_v.size);
struct clip_image_size * load_image_size = clip_image_size_init();
for (size_t i = 0; i < img_res_v.size; i++) {
const int64_t t_img_enc_step_start_us = ggml_time_us();
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip));
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
int patch_size=14;
load_image_size->width = img_res_v.data[i].nx;
load_image_size->height = img_res_v.data[i].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
bool encoded = false;
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
if (clip_is_qwen2vl(ctx_clip)) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
else {
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
}
if (!encoded) {
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
return false;
@@ -290,8 +298,11 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
int n_img_pos_out = 0;
for (size_t i = 0; i < image_embd_v.size(); i++) {
std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip));
n_img_pos_out += clip_n_patches(ctx_clip);
std::memcpy(
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
image_embd_v[i],
clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
}
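The copy now advances by each sub-image's own embedding size; a hedged sanity check of the invariant inside the loop above (requires <cassert>; names as in the loop):

    // Hypothetical assert: per-image bytes == patches * embd_dim * sizeof(float),
    // so n_img_pos_out tracks the float write offset exactly.
    assert(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)
           == (size_t) clip_n_patches_by_img(ctx_clip, &img_res_v.data[i])
              * clip_n_mmproj_embd(ctx_clip) * sizeof(float));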
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
@@ -387,7 +398,13 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
float * image_embd;
if (clip_is_qwen2vl(ctx_clip)) {
// qwen2vl doesn't split the image into chunks, so `num_max_patches` is not needed.
image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny));
} else {
image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
}
if (!image_embd) {
LOG_ERR("Unable to allocate memory for image embeddings\n");
return false;
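Back-of-envelope for the qwen2vl branch, assuming the 2x2 merger divides the patch count by 4 and a 3584-dim projector output as in the 7B model:

    // Hypothetical sizing for one 448x448 image:
    const int patch = 14, dim = 3584;
    const int n_patches = (448 / patch) * (448 / patch);        // 32 * 32 = 1024
    const int n_pos     = n_patches / 4;                        // 256 after 2x2 merge
    const size_t bytes  = (size_t) n_pos * dim * sizeof(float); // ~3.7 MB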

src/llama.cpp

@@ -3333,6 +3333,10 @@ struct llama_context {
// whether we are computing encoder output or decoder output
bool is_encoding = false;
// number of position ids each token gets; 1 for each token in most cases.
// when using m-rope, each token gets 3 position ids, representing a 3-dimensional coordinate.
int n_pos_per_token = 1;
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
@@ -16874,6 +16878,7 @@ static struct ggml_cgraph * llama_build_graph(
} break;
case LLM_ARCH_QWEN2VL:
{
lctx.n_pos_per_token = 3;
result = llm.build_qwen2vl();
} break;
case LLM_ARCH_QWEN2MOE:
@@ -17098,8 +17103,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
if (ubatch.pos && lctx.inp_pos) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
auto n_pos = lctx.n_pos_per_token;
ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
}
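With n_pos_per_token == 3, ubatch.pos must carry three planes of n_tokens positions laid out plane-major, mirroring the clip position buffer above. A hedged sketch of filling it for a text-only span, where all three coordinate streams advance together (assumes <vector> and llama.h; the function is hypothetical):

    // Hypothetical: build plane-major m-rope positions for n_tokens text tokens
    // starting at absolute position p0 (all three streams identical for text).
    std::vector<llama_pos> mrope_text_positions(int n_tokens, llama_pos p0) {
        std::vector<llama_pos> pos(n_tokens * 3);
        for (int i = 0; i < n_tokens; i++) {
            pos[i]                = p0 + i; // temporal
            pos[n_tokens + i]     = p0 + i; // height
            pos[2 * n_tokens + i] = p0 + i; // width
        }
        return pos; // point batch.pos at pos.data() before llama_decode
    }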
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {