introduce pad-to-square mode for non-square images
This commit is contained in:
parent
4759bfd64c
commit
325d240061
3 changed files with 62 additions and 13 deletions
|
@ -711,14 +711,48 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||||
|
|
||||||
// normalize: x = (x - mean) / std
|
// normalize: x = (x - mean) / std
|
||||||
// TODO: implement bicubic interpolation instead of linear.
|
// TODO: implement bicubic interpolation instead of linear.
|
||||||
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res) {
|
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
|
||||||
if (!ctx->has_vision_encoder) {
|
if (!ctx->has_vision_encoder) {
|
||||||
printf("This gguf file seems to have no vision encoder\n");
|
printf("This gguf file seems to have no vision encoder\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nx = img->nx;
|
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
|
||||||
const int ny = img->ny;
|
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
|
||||||
|
|
||||||
|
clip_image_u8 temp; // we will keep the input image data here temporarily
|
||||||
|
if (pad2square && img->nx != img->ny) {
|
||||||
|
int longer_side = std::max(img->nx, img->ny);
|
||||||
|
temp.nx = longer_side;
|
||||||
|
temp.ny = longer_side;
|
||||||
|
temp.size = 3 * longer_side * longer_side;
|
||||||
|
temp.data = new uint8_t[temp.size]();
|
||||||
|
uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA
|
||||||
|
|
||||||
|
// fill with background color
|
||||||
|
for (int i = 0; i < temp.size; i++) {
|
||||||
|
temp.data[i] = bc[i % 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy from the input image
|
||||||
|
for (int y = 0; y < img->ny; y++) {
|
||||||
|
for (int x = 0; x < img->nx; x++) {
|
||||||
|
const int i = 3 * (y * img->nx + x);
|
||||||
|
const int j = 3 * (y * temp.nx + x);
|
||||||
|
temp.data[j] = img->data[i];
|
||||||
|
temp.data[j+1] = img->data[i+1];
|
||||||
|
temp.data[j+2] = img->data[i+2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
temp.nx = img->nx;
|
||||||
|
temp.ny = img->ny;
|
||||||
|
temp.size = img->size;
|
||||||
|
temp.data = img->data;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int nx = temp.nx;
|
||||||
|
const int ny = temp.ny;
|
||||||
|
|
||||||
const int nx2 = ctx->vision_model.hparams.image_size;
|
const int nx2 = ctx->vision_model.hparams.image_size;
|
||||||
const int ny2 = ctx->vision_model.hparams.image_size;
|
const int ny2 = ctx->vision_model.hparams.image_size;
|
||||||
|
@ -757,10 +791,10 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
|
||||||
const int j10 = 3 * (y1 * nx + x0) + c;
|
const int j10 = 3 * (y1 * nx + x0) + c;
|
||||||
const int j11 = 3 * (y1 * nx + x1) + c;
|
const int j11 = 3 * (y1 * nx + x1) + c;
|
||||||
|
|
||||||
const float v00 = img->data[j00];
|
const float v00 = temp.data[j00];
|
||||||
const float v01 = img->data[j01];
|
const float v01 = temp.data[j01];
|
||||||
const float v10 = img->data[j10];
|
const float v10 = temp.data[j10];
|
||||||
const float v11 = img->data[j11];
|
const float v11 = temp.data[j11];
|
||||||
|
|
||||||
const float v0 = v00 * (1.0f - dx) + v01 * dx;
|
const float v0 = v00 * (1.0f - dx) + v01 * dx;
|
||||||
const float v1 = v10 * (1.0f - dx) + v11 * dx;
|
const float v1 = v10 * (1.0f - dx) + v11 * dx;
|
||||||
|
@ -1021,8 +1055,12 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
|
int clip_n_pos(struct clip_ctx * ctx) {
|
||||||
auto & params = ctx->vision_model.hparams;
|
auto & params = ctx->vision_model.hparams;
|
||||||
|
|
||||||
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size) * 4096 * sizeof(float);
|
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
|
||||||
|
return clip_n_pos(ctx) * 4096 * sizeof(float);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
|
||||||
void clip_free(struct clip_ctx * ctx);
|
void clip_free(struct clip_ctx * ctx);
|
||||||
|
|
||||||
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
||||||
|
int clip_n_pos(struct clip_ctx * ctx);
|
||||||
|
|
||||||
// RGB uint8 image
|
// RGB uint8 image
|
||||||
struct clip_image_u8 {
|
struct clip_image_u8 {
|
||||||
|
@ -56,7 +57,7 @@ struct clip_image_f32_batch {
|
||||||
struct clip_image_u8 * make_clip_image_u8();
|
struct clip_image_u8 * make_clip_image_u8();
|
||||||
struct clip_image_f32 * make_clip_image_f32();
|
struct clip_image_f32 * make_clip_image_f32();
|
||||||
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||||
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res);
|
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
|
||||||
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
|
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
|
||||||
|
|
||||||
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
|
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
|
||||||
|
|
|
@ -9,6 +9,8 @@
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
ggml_time_init();
|
||||||
|
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
if (argc < 4) {
|
if (argc < 4) {
|
||||||
|
@ -30,24 +32,29 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||||
|
|
||||||
// load and preprocess the iamge
|
// load and preprocess the image
|
||||||
clip_image_u8 img;
|
clip_image_u8 img;
|
||||||
clip_image_f32 img_res;
|
clip_image_f32 img_res;
|
||||||
clip_image_load_from_file(img_path, &img);
|
clip_image_load_from_file(img_path, &img);
|
||||||
clip_image_preprocess(ctx_clip, &img, &img_res);
|
|
||||||
|
|
||||||
|
clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true);
|
||||||
|
|
||||||
|
int n_img_pos = clip_n_pos(ctx_clip);
|
||||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||||
|
|
||||||
if (!image_embd) {
|
if (!image_embd) {
|
||||||
fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
|
fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||||
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
|
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
|
||||||
fprintf(stderr, "Unable to encode image\n");
|
fprintf(stderr, "Unable to encode image\n");
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
const int64_t t_img_enc_end_us = ggml_time_us();
|
||||||
|
|
||||||
// we get the embeddings, free up the memory required for CLIP
|
// we get the embeddings, free up the memory required for CLIP
|
||||||
clip_free(ctx_clip);
|
clip_free(ctx_clip);
|
||||||
|
@ -80,7 +87,7 @@ int main(int argc, char ** argv) {
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
int max_tgt_len = 256;
|
int max_tgt_len = 256;
|
||||||
eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
|
eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
|
||||||
eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, params.n_batch, &n_past);
|
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
||||||
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
||||||
eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
|
eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
|
||||||
|
|
||||||
|
@ -95,6 +102,9 @@ eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
const float img_enc_duration = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
||||||
|
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, img_enc_duration, img_enc_duration / n_img_pos);
|
||||||
|
|
||||||
llama_print_timings(ctx_llama);
|
llama_print_timings(ctx_llama);
|
||||||
|
|
||||||
llama_free(ctx_llama);
|
llama_free(ctx_llama);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue