diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 89af6bb4f..a883d8d80 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -711,14 +711,48 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
 
 // normalize: x = (x - mean) / std
 // TODO: implement bicubic interpolation instead of linear.
-bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res) {
+bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
     if (!ctx->has_vision_encoder) {
         printf("This gguf file seems to have no vision encoder\n");
         return false;
     }
 
-    const int nx = img->nx;
-    const int ny = img->ny;
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8 temp; // we will keep the input image data here temporarily
+    if (pad2square && img->nx != img->ny) {
+        int longer_side = std::max(img->nx, img->ny);
+        temp.nx = longer_side;
+        temp.ny = longer_side;
+        temp.size = 3 * longer_side * longer_side;
+        temp.data = new uint8_t[temp.size]();
+        uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA
+
+        // fill with background color
+        for (int i = 0; i < temp.size; i++) {
+            temp.data[i] = bc[i % 3];
+        }
+
+        // copy from the input image
+        for (int y = 0; y < img->ny; y++) {
+            for (int x = 0; x < img->nx; x++) {
+                const int i = 3 * (y * img->nx + x);
+                const int j = 3 * (y * temp.nx + x);
+                temp.data[j] = img->data[i];
+                temp.data[j+1] = img->data[i+1];
+                temp.data[j+2] = img->data[i+2];
+            }
+        }
+    } else {
+        temp.nx = img->nx;
+        temp.ny = img->ny;
+        temp.size = img->size;
+        temp.data = img->data;
+    }
+
+    const int nx = temp.nx;
+    const int ny = temp.ny;
 
     const int nx2 = ctx->vision_model.hparams.image_size;
     const int ny2 = ctx->vision_model.hparams.image_size;
@@ -757,10 +791,10 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
                 const int j10 = 3 * (y1 * nx + x0) + c;
                 const int j11 = 3 * (y1 * nx + x1) + c;
 
-                const float v00 = img->data[j00];
-                const float v01 = img->data[j01];
-                const float v10 = img->data[j10];
-                const float v11 = img->data[j11];
+                const float v00 = temp.data[j00];
+                const float v01 = temp.data[j01];
+                const float v10 = temp.data[j10];
+                const float v11 = temp.data[j11];
 
                 const float v0 = v00 * (1.0f - dx) + v01 * dx;
                 const float v1 = v10 * (1.0f - dx) + v11 * dx;
@@ -1021,8 +1055,12 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
     return true;
 }
 
-size_t clip_embd_nbytes(struct clip_ctx * ctx) {
+int clip_n_pos(struct clip_ctx * ctx) {
     auto & params = ctx->vision_model.hparams;
-    return (params.image_size / params.patch_size) * (params.image_size / params.patch_size) * 4096 * sizeof(float);
+    return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+}
+
+size_t clip_embd_nbytes(struct clip_ctx * ctx) {
+    return clip_n_pos(ctx) * 4096 * sizeof(float);
 }
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index 303b11436..c651332d6 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -25,6 +25,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
 void clip_free(struct clip_ctx * ctx);
 
 size_t clip_embd_nbytes(struct clip_ctx * ctx);
+int clip_n_pos(struct clip_ctx * ctx);
 
 // RGB uint8 image
 struct clip_image_u8 {
@@ -56,7 +57,7 @@ struct clip_image_f32_batch {
 struct clip_image_u8 * make_clip_image_u8();
 struct clip_image_f32 * make_clip_image_f32();
 bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
-bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res);
+bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
 bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
 
 bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 62cdbf700..0e2fd0492 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -9,6 +9,8 @@
 
 int main(int argc, char ** argv) {
+    ggml_time_init();
+
     gpt_params params;
 
     if (argc < 4) {
@@ -30,24 +32,29 @@
 
     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
 
-    // load and preprocess the iamge
+    // load and preprocess the image
     clip_image_u8 img;
     clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
 
-    clip_image_preprocess(ctx_clip, &img, &img_res);
+    clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true);
+
+    int n_img_pos = clip_n_pos(ctx_clip);
 
     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
+
     if (!image_embd) {
         fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
         return 1;
     }
 
+    const int64_t t_img_enc_start_us = ggml_time_us();
     if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
         fprintf(stderr, "Unable to encode image\n");
         return 1;
    }
+    const int64_t t_img_enc_end_us = ggml_time_us();
 
     // we get the embeddings, free up the memory required for CLIP
     clip_free(ctx_clip);
@@ -80,7 +87,7 @@ int main(int argc, char ** argv) {
 
     int n_past = 0;
     int max_tgt_len = 256;
     eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, params.n_batch, &n_past);
+    eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
     eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
     eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
@@ -95,6 +102,9 @@ eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
     }
     printf("\n");
 
+    const float img_enc_duration = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
+    printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, img_enc_duration, img_enc_duration / n_img_pos);
+
     llama_print_timings(ctx_llama);
 
     llama_free(ctx_llama);
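Note for reviewers: the pad2square branch is easiest to sanity-check in isolation. Below is a minimal standalone sketch of the same padding logic; `image_u8` and `pad_to_square` are hypothetical stand-ins for `clip_image_u8` and the new branch in `clip_image_preprocess` (using `std::vector` instead of the raw `data`/`size` pair), not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct image_u8 {
    int nx = 0;
    int ny = 0;
    std::vector<uint8_t> data; // packed RGB, 3 bytes per pixel
};

// mirror of the patch's pad2square branch: pad the shorter side to the
// longer side with the LLaVA background color rgb(122, 116, 104)
static image_u8 pad_to_square(const image_u8 & img) {
    if (img.nx == img.ny) {
        return img; // already square: the patch also skips padding here
    }
    const int side = std::max(img.nx, img.ny);
    const uint8_t bc[3] = {122, 116, 104};

    image_u8 out;
    out.nx = side;
    out.ny = side;
    out.data.resize((size_t) 3 * side * side);

    // fill the whole canvas with the background color
    for (size_t i = 0; i < out.data.size(); i++) {
        out.data[i] = bc[i % 3];
    }
    // copy the source image into the top-left corner, as the patch does
    for (int y = 0; y < img.ny; y++) {
        for (int x = 0; x < img.nx; x++) {
            const size_t i = 3 * (size_t)(y * img.nx + x);
            const size_t j = 3 * (size_t)(y * out.nx + x);
            out.data[j]     = img.data[i];
            out.data[j + 1] = img.data[i + 1];
            out.data[j + 2] = img.data[i + 2];
        }
    }
    return out;
}

int main() {
    image_u8 img;
    img.nx = 4;
    img.ny = 2;
    img.data.assign(3 * 4 * 2, 255); // a white 4x2 image

    const image_u8 sq = pad_to_square(img);
    printf("%d x %d\n", sq.nx, sq.ny);                                  // 4 x 4
    printf("row 3, pixel 0, R = %d\n", (int) sq.data[3 * (3 * sq.nx)]); // 122 (background)
    return 0;
}
```

One thing the sketch sidesteps deliberately: in the patch, the padded branch allocates `temp.data` with `new` and, as far as these hunks show, never frees it, while the else branch aliases the caller's buffer (so an unconditional `delete[]` would be wrong). An owning container as above, or a `delete[]` guarded by the same `pad2square && img->nx != img->ny` condition, would close the leak.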
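The new `clip_n_pos()` makes the position count explicit: the encoder yields one embedding per image patch, i.e. `(image_size / patch_size)^2` positions, which is what allows dropping the hardcoded `/*n_pos_image=*/ 576` in llava.cpp. A quick check of that arithmetic; the concrete hyperparameters here (image_size = 336, patch_size = 14) are assumptions chosen because they reproduce the old 576 constant:

```cpp
#include <cassert>
#include <cstdio>

int main() {
    const int image_size = 336; // assumed vision_model.hparams.image_size
    const int patch_size = 14;  // assumed vision_model.hparams.patch_size

    // same formula as clip_n_pos(): one position per image patch
    const int n_pos = (image_size / patch_size) * (image_size / patch_size);
    assert(n_pos == 576); // 24 * 24, matching the old hardcoded n_pos_image

    // clip_embd_nbytes() = n_pos * 4096 floats; the 4096 is still hardcoded
    printf("n_pos = %d, embd bytes = %zu\n", n_pos, (size_t) n_pos * 4096 * sizeof(float));
    return 0;
}
```

At these values the embedding buffer is 576 * 4096 * 4 bytes, about 9.4 MB per image. Deriving both the buffer size and the `eval_image_embd` position count from `clip_n_pos()` keeps the two from drifting apart, though the 4096 (presumably the LLaMA-7B embedding width) remains hardcoded in `clip_embd_nbytes()`.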