introduce pad-to-square mode for non-square images

This commit is contained in:
M. Yusuf Sarıgöz 2023-10-09 23:53:29 +03:00
parent 4759bfd64c
commit 325d240061
3 changed files with 62 additions and 13 deletions

View file

@ -711,14 +711,48 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
// normalize: x = (x - mean) / std
// TODO: implement bicubic interpolation instead of linear.
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res) {
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
if (!ctx->has_vision_encoder) {
printf("This gguf file seems to have no vision encoder\n");
return false;
}
const int nx = img->nx;
const int ny = img->ny;
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
clip_image_u8 temp; // we will keep the input image data here temporarily
if (pad2square && img->nx != img->ny) {
int longer_side = std::max(img->nx, img->ny);
temp.nx = longer_side;
temp.ny = longer_side;
temp.size = 3 * longer_side * longer_side;
temp.data = new uint8_t[temp.size]();
uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA
// fill with background color
for (int i = 0; i < temp.size; i++) {
temp.data[i] = bc[i % 3];
}
// copy from the input image
for (int y = 0; y < img->ny; y++) {
for (int x = 0; x < img->nx; x++) {
const int i = 3 * (y * img->nx + x);
const int j = 3 * (y * temp.nx + x);
temp.data[j] = img->data[i];
temp.data[j+1] = img->data[i+1];
temp.data[j+2] = img->data[i+2];
}
}
} else {
temp.nx = img->nx;
temp.ny = img->ny;
temp.size = img->size;
temp.data = img->data;
}
const int nx = temp.nx;
const int ny = temp.ny;
const int nx2 = ctx->vision_model.hparams.image_size;
const int ny2 = ctx->vision_model.hparams.image_size;
@ -757,10 +791,10 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
const int j10 = 3 * (y1 * nx + x0) + c;
const int j11 = 3 * (y1 * nx + x1) + c;
const float v00 = img->data[j00];
const float v01 = img->data[j01];
const float v10 = img->data[j10];
const float v11 = img->data[j11];
const float v00 = temp.data[j00];
const float v01 = temp.data[j01];
const float v10 = temp.data[j10];
const float v11 = temp.data[j11];
const float v0 = v00 * (1.0f - dx) + v01 * dx;
const float v1 = v10 * (1.0f - dx) + v11 * dx;
@ -1021,8 +1055,12 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
return true;
}
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
int clip_n_pos(struct clip_ctx * ctx) {
auto & params = ctx->vision_model.hparams;
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size) * 4096 * sizeof(float);
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
}
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
return clip_n_pos(ctx) * 4096 * sizeof(float);
}

View file

@ -25,6 +25,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
void clip_free(struct clip_ctx * ctx);
size_t clip_embd_nbytes(struct clip_ctx * ctx);
int clip_n_pos(struct clip_ctx * ctx);
// RGB uint8 image
struct clip_image_u8 {
@ -56,7 +57,7 @@ struct clip_image_f32_batch {
struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32();
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,

View file

@ -9,6 +9,8 @@
int main(int argc, char ** argv) {
ggml_time_init();
gpt_params params;
if (argc < 4) {
@ -30,24 +32,29 @@ int main(int argc, char ** argv) {
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
// load and preprocess the iamge
// load and preprocess the image
clip_image_u8 img;
clip_image_f32 img_res;
clip_image_load_from_file(img_path, &img);
clip_image_preprocess(ctx_clip, &img, &img_res);
clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true);
int n_img_pos = clip_n_pos(ctx_clip);
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) {
fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
return 1;
}
const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
fprintf(stderr, "Unable to encode image\n");
return 1;
}
const int64_t t_img_enc_end_us = ggml_time_us();
// we get the embeddings, free up the memory required for CLIP
clip_free(ctx_clip);
@ -80,7 +87,7 @@ int main(int argc, char ** argv) {
int n_past = 0;
int max_tgt_len = 256;
eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, params.n_batch, &n_past);
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
@ -95,6 +102,9 @@ eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
}
printf("\n");
const float img_enc_duration = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, img_enc_duration, img_enc_duration / n_img_pos);
llama_print_timings(ctx_llama);
llama_free(ctx_llama);