From 8224ca5775b7f09f088abf2379fcac25270085d4 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 14 Oct 2023 10:43:13 +0200
Subject: [PATCH] wip refactor image loading

---
 examples/llava/llava-utils.h |  52 +++++++++++++++
 examples/llava/llava.cpp     | 122 +++++++++++------------------------
 2 files changed, 91 insertions(+), 83 deletions(-)

diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h
index 79e237c86..db8af6d6c 100644
--- a/examples/llava/llava-utils.h
+++ b/examples/llava/llava-utils.h
@@ -143,3 +143,55 @@ inline const char * sample(struct llama_context * ctx_llama, gpt_params & params
     eval_id(ctx_llama, id, n_past);
     return ret.c_str();
 }
+
+static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
+static const char* IMG_BASE64_TAG_END = "\">";
+
+static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
+    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
+    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
+}
+
+static bool prompt_contains_image(const std::string& prompt) {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    return (begin != std::string::npos);
+}
+
+static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
+    size_t img_base64_str_start, img_base64_str_end;
+    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
+    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
+        fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+        return false;
+    }
+
+    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
+    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
+    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count);
+
+    auto required_bytes = base64::required_encode_size(base64_str.size());
+    auto img_bytes = std::vector<unsigned char>(required_bytes);
+    auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
+    auto img_bytes_len = img_bytes_end - img_bytes.begin();
+
+    auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
+    if (!img_loaded_ok) {
+        fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+// replaces the base64 image tag in the prompt with `replacement`
+static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
+    size_t begin, end;
+    find_image_tag_in_prompt(prompt, begin, end);
+    if (begin == std::string::npos || end == std::string::npos) {
+        return prompt;
+    }
+    auto pre = prompt.substr(0, begin);
+    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
+    return pre + replacement + post;
+}
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index bfa2f72a5..7781d4222 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -37,58 +37,28 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos, float * t_img_enc_ms) {
     return true;
 }
 
-static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
-static const char* IMG_BASE64_TAG_END = "\">";
+bool llava_build_img_embed(struct llava_context * ctx_llava, int n_threads, const clip_image_u8 * img) {
 
-static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
-    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
-    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
-}
-
-static bool prompt_contains_image(const std::string& prompt) {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    return (begin != std::string::npos);
-}
-
-// replaces the base64 image tag in the prompt with `replacement`
-static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
-    size_t img_base64_str_start, img_base64_str_end;
-    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
-    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
-        fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+    float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip));
+    if (!image_embd) {
+        fprintf(stderr, "Unable to allocate memory for image embeddings\n");
+        free(image_embd);
         return false;
     }
 
-    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
-    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
-    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
-    printf("base64_str: '%s'\n", base64_str.c_str());
-
-    auto required_bytes = base64::required_encode_size(base64_str.size());
-    auto img_bytes = std::vector<unsigned char>(required_bytes);
-    auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
-    auto img_bytes_len = img_bytes_end - img_bytes.begin();
-
-    auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
-    if (!img_loaded_ok) {
-        fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+    int n_img_embd;
+    int n_img_pos;
+    float t_img_enc_ms;
+    if (!encode_image_with_clip(ctx_llava->ctx_clip, n_threads, img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
+        fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
+        free(image_embd);
         return false;
     }
 
-    return true;
+    ctx_llava->image_embd = image_embd;
+    return true;
 }
 
-static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    if (begin == std::string::npos || end == std::string::npos) {
-        return prompt;
-    }
-    auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end+1);
-    return pre + replacement + post;
-}
-
 struct llava_context * llava_init(gpt_params * params) {
@@ -102,46 +72,6 @@ struct llava_context * llava_init(gpt_params * params) {
 
     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
 
-    // load and preprocess the image
-    clip_image_u8 img;
-
-    if (prompt_contains_image(prompt)) {
-        if (img_path) {
-            printf("using base64 encoded image instead of command line image path\n");
-        }
-        if (!get_image_from_prompt(prompt, &img)) {
-            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
-            clip_free(ctx_clip);
-            return NULL;
-        }
-        prompt = remove_image_from_prompt(prompt);
-    } else {
-        if (!clip_image_load_from_file(img_path, &img)) {
-            fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
-            clip_free(ctx_clip);
-            return NULL;
-        }
-    }
-
-    float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
-    if (!image_embd) {
-        fprintf(stderr, "Unable to allocate memory for image embeddings\n");
-        return NULL;
-    }
-
-    int n_img_embd;
-    int n_img_pos;
-    float t_img_enc_ms;
-    if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
-        fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
-        clip_free(ctx_clip);
-        return NULL;
-    }
-
-    // we get the embeddings, free up the memory required for CLIP
-    clip_free(ctx_clip);
-    ctx_clip = NULL;
-
     llama_backend_init(params->numa);
 
     llama_model_params model_params = llama_model_default_params();
@@ -194,6 +124,11 @@ struct llava_context * llava_init(gpt_params * params) {
 }
 
 void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+
     llama_free(ctx_llava->ctx_llama);
     llama_free_model(ctx_llava->model);
     llama_backend_free();
@@ -249,6 +184,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // load and preprocess the image
+    clip_image_u8 img;
+    if (prompt_contains_image(prompt)) {
+        if (img_path) {
+            printf("using base64 encoded image instead of command line image path\n");
+        }
+        if (!get_image_from_prompt(prompt, &img)) {
+            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
+            llava_free(ctx_llava);
+            return 1;
+        }
+        prompt = remove_image_from_prompt(prompt);
+    } else {
+        if (!clip_image_load_from_file(img_path, &img)) {
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
+            llava_free(ctx_llava);
+            return 1;
+        }
+    }
+    llava_build_img_embed(ctx_llava, params.n_threads, &img);
+
     // process the prompt
     // llava chat format is "USER: <image>\n<prompt>\nASSISTANT:"
     llava_process_prompt(ctx_llava, &params, params.prompt.c_str());
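
Note (not part of the patch): a minimal sketch of how a caller might drive the refactored pieces, assuming the llava_context from this WIP owns ctx_clip and the image_embd buffer, and that llava_init, llava_process_prompt and llava_free keep the signatures visible in the hunks above. The function name, prompt text and error handling here are illustrative only.

// illustrative sketch only -- exercises the base64-image-in-prompt path added above
// assumes the llava example headers (clip.h, common.h, llava-utils.h) are available
int run_llava_with_embedded_image(gpt_params & params) {
    auto ctx_llava = llava_init(&params);
    if (!ctx_llava) {
        return 1;
    }

    std::string prompt = params.prompt;   // e.g. "<img src=\"data:image/jpeg;base64,...\"> describe this image"
    clip_image_u8 img;

    if (prompt_contains_image(prompt)) {
        // decode the base64 payload into `img`, then strip the tag from the prompt text
        if (!get_image_from_prompt(prompt, &img)) {
            llava_free(ctx_llava);
            return 1;
        }
        prompt = remove_image_from_prompt(prompt);
    }

    // CLIP-encode the image; the embedding is kept on ctx_llava->image_embd
    if (!llava_build_img_embed(ctx_llava, params.n_threads, &img)) {
        llava_free(ctx_llava);
        return 1;
    }

    llava_process_prompt(ctx_llava, &params, prompt.c_str());
    llava_free(ctx_llava);
    return 0;
}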