cleanup memory usage around clip_image_*

This commit is contained in:
Damian Stewart 2023-10-15 00:18:04 +02:00
parent 2847ecf2dd
commit e3261ffad3
4 changed files with 43 additions and 32 deletions

View file

@ -22,25 +22,26 @@ static void show_additional_info(int /*argc*/, char ** argv) {
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) { static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
// load and preprocess the image // load and preprocess the image
clip_image_u8 img; clip_image_u8 * img = make_clip_image_u8();
auto prompt = params->prompt; auto prompt = params->prompt;
if (prompt_contains_image(prompt)) { if (prompt_contains_image(prompt)) {
if (!params->image.empty()) { if (!params->image.empty()) {
printf("using base64 encoded image instead of command line image path\n"); printf("using base64 encoded image instead of command line image path\n");
} }
if (!clip_image_load_from_prompt(prompt, &img)) { if (!clip_image_load_from_prompt(prompt, img)) {
fprintf(stderr, "%s: can't load image from prompt\n", __func__); fprintf(stderr, "%s: can't load image from prompt\n", __func__);
return false; return false;
} }
params->prompt = remove_image_from_prompt(prompt); params->prompt = remove_image_from_prompt(prompt);
} else { } else {
if (!clip_image_load_from_file(params->image.c_str(), &img)) { if (!clip_image_load_from_file(params->image.c_str(), img)) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str()); fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
return false; return false;
} }
} }
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, &img, image_embd, n_img_pos); bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos);
if (!image_embed_result) { if (!image_embed_result) {
clip_image_u8_free(img);
fprintf(stderr, "%s: coulnd't embed the image\n", __func__); fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
return false; return false;
} }

View file

@ -679,9 +679,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
} }
clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); } clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); } clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); }
// Releases an image obtained from make_clip_image_u8(): frees the pixel buffer, then the struct itself.
// A null img is ignored, mirroring free()/delete.
// NOTE(review): assumes img->data, when set, was allocated with new[] -- confirm for every loader.
void clip_image_u8_free(clip_image_u8 * img) {
    if (!img) {
        return;
    }
    delete[] img->data; // delete[] on a null pointer is a no-op, so no separate check is needed
    delete img;
}
// Releases an image obtained from make_clip_image_f32(): frees the float buffer, then the struct itself.
// A null img is ignored, mirroring free()/delete.
// NOTE(review): assumes img->data, when set, was allocated with new[] -- confirm against the preprocessing code.
void clip_image_f32_free(clip_image_f32 * img) {
    if (!img) {
        return;
    }
    delete[] img->data; // delete[] on a null pointer is a no-op, so no separate check is needed
    delete img;
}
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) { static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
img->nx = nx; img->nx = nx;
img->ny = ny; img->ny = ny;
@ -726,39 +728,40 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
clip_image_u8 temp; // we will keep the input image data here temporarily clip_image_u8 * temp = make_clip_image_u8(); // we will keep the input image data here temporarily
if (pad2square && img->nx != img->ny) { if (pad2square && img->nx != img->ny) {
int longer_side = std::max(img->nx, img->ny); int longer_side = std::max(img->nx, img->ny);
temp.nx = longer_side; temp->nx = longer_side;
temp.ny = longer_side; temp->ny = longer_side;
temp.size = 3 * longer_side * longer_side; temp->size = 3 * longer_side * longer_side;
temp.data = new uint8_t[temp.size](); temp->data = new uint8_t[temp->size]();
uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA
// fill with background color // fill with background color
for (size_t i = 0; i < temp.size; i++) { for (size_t i = 0; i < temp->size; i++) {
temp.data[i] = bc[i % 3]; temp->data[i] = bc[i % 3];
} }
// copy from the input image // copy from the input image
for (int y = 0; y < img->ny; y++) { for (int y = 0; y < img->ny; y++) {
for (int x = 0; x < img->nx; x++) { for (int x = 0; x < img->nx; x++) {
const int i = 3 * (y * img->nx + x); const int i = 3 * (y * img->nx + x);
const int j = 3 * (y * temp.nx + x); const int j = 3 * (y * temp->nx + x);
temp.data[j] = img->data[i]; temp->data[j] = img->data[i];
temp.data[j+1] = img->data[i+1]; temp->data[j+1] = img->data[i+1];
temp.data[j+2] = img->data[i+2]; temp->data[j+2] = img->data[i+2];
} }
} }
} else { } else {
temp.nx = img->nx; temp->nx = img->nx;
temp.ny = img->ny; temp->ny = img->ny;
temp.size = img->size; temp->size = img->size;
temp.data = img->data; temp->data = new uint8_t[temp->size]();
*temp->data = *img->data; // copy
} }
const int nx = temp.nx; const int nx = temp->nx;
const int ny = temp.ny; const int ny = temp->ny;
const int nx2 = ctx->vision_model.hparams.image_size; const int nx2 = ctx->vision_model.hparams.image_size;
const int ny2 = ctx->vision_model.hparams.image_size; const int ny2 = ctx->vision_model.hparams.image_size;
@ -797,10 +800,10 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
const int j10 = 3 * (y1 * nx + x0) + c; const int j10 = 3 * (y1 * nx + x0) + c;
const int j11 = 3 * (y1 * nx + x1) + c; const int j11 = 3 * (y1 * nx + x1) + c;
const float v00 = temp.data[j00]; const float v00 = temp->data[j00];
const float v01 = temp.data[j01]; const float v01 = temp->data[j01];
const float v10 = temp.data[j10]; const float v10 = temp->data[j10];
const float v11 = temp.data[j11]; const float v11 = temp->data[j11];
const float v0 = v00 * (1.0f - dx) + v01 * dx; const float v0 = v00 * (1.0f - dx) + v01 * dx;
const float v1 = v10 * (1.0f - dx) + v11 * dx; const float v1 = v10 * (1.0f - dx) + v11 * dx;
@ -815,6 +818,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
} }
} }
} }
clip_image_u8_free(temp);
return true; return true;
} }

View file

@ -33,7 +33,7 @@ int clip_n_mmproj_embd(struct clip_ctx * ctx);
struct clip_image_u8 { struct clip_image_u8 {
int nx; int nx;
int ny; int ny;
uint8_t * data; uint8_t * data = NULL;
size_t size; size_t size;
}; };
@ -42,7 +42,7 @@ struct clip_image_u8 {
struct clip_image_f32 { struct clip_image_f32 {
int nx; int nx;
int ny; int ny;
float * data; float * data = NULL;
size_t size; size_t size;
}; };
@ -58,8 +58,12 @@ struct clip_image_f32_batch {
struct clip_image_u8 * make_clip_image_u8(); struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32(); struct clip_image_f32 * make_clip_image_f32();
LLAMA_API void clip_image_u8_free(clip_image_u8 * img);
LLAMA_API void clip_image_f32_free(clip_image_f32 * img);
LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img); LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);

View file

@ -11,10 +11,10 @@
#include "base64.hpp" #include "base64.hpp"
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) { static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
clip_image_f32 img_res; clip_image_f32 * img_res = make_clip_image_f32();
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) { if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess image\n", __func__); fprintf(stderr, "%s: unable to preprocess image\n", __func__);
clip_image_f32_free(img_res);
return false; return false;
} }
@ -22,7 +22,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
*n_image_embd = clip_n_mmproj_embd(ctx_clip); *n_image_embd = clip_n_mmproj_embd(ctx_clip);
const int64_t t_img_enc_start_us = ggml_time_us(); const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) { bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
clip_image_f32_free(img_res);
if (!encoded) {
fprintf(stderr, "Unable to encode image\n"); fprintf(stderr, "Unable to encode image\n");
return false; return false;