imitate reshape bug of python code

This commit is contained in:
caitianchi 2024-07-04 17:25:02 +08:00
parent 4c67d7cef5
commit 977941d9fe
3 changed files with 50 additions and 16 deletions

View file

@ -410,13 +410,10 @@ void llava_image_embed_free(struct llava_image_embed * embed) {
free(embed);
}
static bool encode_image_with_clip_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
static bool encode_image_with_clip_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos, std::pair<int, int> load_image_size) {
// std::vector<clip_image_f32*> img_res_v;
// format VectN x H x W x RGB (N x 448 x 448 x 3)
clip_image_f32 * img_res_v = clip_image_f32_init();
std::pair<int, int> load_image_size;
load_image_size.first = img->nx;
load_image_size.second = img->ny;
uhd_normalize_image_u8_to_f32(ctx_clip, img, img_res_v);
const int64_t t_img_enc_start_us = ggml_time_us();
@ -545,6 +542,34 @@ static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int tar
return true;
}
// Rearranges `image` into a single horizontal strip of square tiles.
//
// The input is read as an interleaved buffer with 3 bytes per pixel and
// `image->nx` pixels per row. It is cut into patch_size x patch_size tiles,
// scanned row-major (left to right, then top to bottom). Tile k is copied to
// columns [k * patch_size, (k + 1) * patch_size) of the output, so the result
// is `patch_size * num_patches` pixels wide and `patch_size` pixels tall.
//
// Only complete tiles are copied: if a dimension is not a multiple of
// `patch_size`, the trailing partial row/column of pixels is dropped. This
// keeps the number of copied tiles equal to `num_patches`, which is what the
// output buffer is sized for.
//
// Parameters:
//   image      - source image; not modified by this function.
//   patch_size - tile edge length in pixels; must be > 0 (it is used as a
//                divisor and as the loop step).
//
// Returns a new image obtained from clip_image_u8_init().
// NOTE(review): the caller presumably owns the returned image and must release
// it with the matching clip free routine - confirm against the clip API.
static clip_image_u8 * only_v2_5_reshape_by_patch(clip_image_u8 * image, int patch_size) {
    const int width       = image->nx;
    const int height      = image->ny;
    const int num_patches = (height / patch_size) * (width / patch_size);

    clip_image_u8 * patch = clip_image_u8_init();
    patch->nx = patch_size * num_patches;
    patch->ny = patch_size;
    patch->buf.resize(3 * patch->nx * patch->ny);

    int patch_index = 0;
    // `+ patch_size <=` (rather than `<`) skips partial tiles at the right and
    // bottom edges; visiting them would read past image->buf and write past
    // patch->buf, because num_patches above counts complete tiles only.
    for (int i = 0; i + patch_size <= height; i += patch_size) {
        for (int j = 0; j + patch_size <= width; j += patch_size) {
            for (int pi = 0; pi < patch_size; ++pi) {
                for (int pj = 0; pj < patch_size; ++pj) {
                    const int input_index  = ((i + pi) * width + (j + pj)) * 3;
                    // Output row `pi` is patch_size * num_patches pixels long;
                    // tile `patch_index` starts at column patch_index * patch_size.
                    const int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3;
                    patch->buf[output_index]     = image->buf[input_index];
                    patch->buf[output_index + 1] = image->buf[input_index + 1];
                    patch->buf[output_index + 2] = image->buf[input_index + 2];
                }
            }
            patch_index++;
        }
    }
    return patch;
}
// inspired by LLaVA-UHD:
// -> https://arxiv.org/pdf/2403.11703
// -> https://github.com/thunlp/LLaVA-UHD
@ -657,7 +682,11 @@ struct uhd_image_embed * llava_image_embed_make_with_bytes_uhd(struct clip_ctx *
for (size_t j = 0; j < imgs[i].size(); ++j) {
float* image_embed = NULL;
int n_image_pos = 0;
bool image_embed_result = llava_image_embed_make_with_clip_img_uhd(ctx_clip, n_threads, imgs[i][j], &image_embed, &n_image_pos);
int patch_size=14;
std::pair<int, int> load_image_size;
load_image_size.first = imgs[i][j]->nx;
load_image_size.second = imgs[i][j]->ny;
bool image_embed_result = llava_image_embed_make_with_clip_img_uhd(ctx_clip, n_threads, only_v2_5_reshape_by_patch(imgs[i][j], patch_size), &image_embed, &n_image_pos, load_image_size);
if (!image_embed_result) {
LOG_TEE("%s: coulnd't embed the image\n", __func__);
return NULL;
@ -672,7 +701,7 @@ struct uhd_image_embed * llava_image_embed_make_with_bytes_uhd(struct clip_ctx *
return results;
}
bool llava_image_embed_make_with_clip_img_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
bool llava_image_embed_make_with_clip_img_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out, std::pair<int, int> load_image_size) {
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6); // TODO: base on gridsize/llava model
if (!image_embd) {
LOG_TEE("Unable to allocate memory for image embeddings\n");
@ -680,7 +709,7 @@ bool llava_image_embed_make_with_clip_img_uhd(clip_ctx * ctx_clip, int n_threads
}
int n_img_pos;
if (!encode_image_with_clip_uhd(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
if (!encode_image_with_clip_uhd(ctx_clip, n_threads, img, image_embd, &n_img_pos, load_image_size)) {
LOG_TEE("%s: cannot encode image, aborting\n", __func__);
free(image_embd);
return false;