diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h
index 79e237c86..db8af6d6c 100644
--- a/examples/llava/llava-utils.h
+++ b/examples/llava/llava-utils.h
@@ -143,3 +143,55 @@ inline const char * sample(struct llama_context * ctx_llama, gpt_params & params
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}
+
+static const char* IMG_BASE64_TAG_BEGIN = "
";
+
+static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
+ begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
+ end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
+}
+
+static bool prompt_contains_image(const std::string& prompt) {
+ size_t begin, end;
+ find_image_tag_in_prompt(prompt, begin, end);
+ return (begin != std::string::npos);
+}
+
+// replaces the base64 image tag in the prompt with `replacement`
+static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
+ size_t img_base64_str_start, img_base64_str_end;
+ find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
+ if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
+ fprintf(stderr, "%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+ return false;
+ }
+
+ auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
+ auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
+ auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
+
+ auto required_bytes = base64::required_encode_size(base64_str.size());
+ auto img_bytes = std::vector(required_bytes);
+ auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
+ auto img_bytes_len = img_bytes_end - img_bytes.begin();
+
+ auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
+ if (!img_loaded_ok) {
+ fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+ return false;
+ }
+
+ return true;
+}
+
+static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
+ size_t begin, end;
+ find_image_tag_in_prompt(prompt, begin, end);
+ if (begin == std::string::npos || end == std::string::npos) {
+ return prompt;
+ }
+ auto pre = prompt.substr(0, begin);
+ auto post = prompt.substr(end+1);
+ return pre + replacement + post;
+}
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index bfa2f72a5..7781d4222 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -37,58 +37,28 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
return true;
}
-static const char* IMG_BASE64_TAG_BEGIN = "
";
+bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 * img) {
-static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
- begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
- end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
-}
-
-static bool prompt_contains_image(const std::string& prompt) {
- size_t begin, end;
- find_image_tag_in_prompt(prompt, begin, end);
- return (begin != std::string::npos);
-}
-
-// replaces the base64 image tag in the prompt with `replacement`
-static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
- size_t img_base64_str_start, img_base64_str_end;
- find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
- if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
- fprintf(stderr, "%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
+ float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
+ if (!image_embd) {
+ fprintf(stderr, "Unable to allocate memory for image embeddings\n");
+ free(image_embd);
return false;
}
- auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
- auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
- auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
- printf("base64_str: '%s'\n", base64_str.c_str());
-
- auto required_bytes = base64::required_encode_size(base64_str.size());
- auto img_bytes = std::vector(required_bytes);
- auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
- auto img_bytes_len = img_bytes_end - img_bytes.begin();
-
- auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
- if (!img_loaded_ok) {
- fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
+ int n_img_embd;
+ int n_img_pos;
+ float t_img_enc_ms;
+ if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
+ fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
+ free(image_embd);
return false;
}
- return true;
+ ctx_llava->image_embd = image_embd;
+ retur true;
}
-static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
- size_t begin, end;
- find_image_tag_in_prompt(prompt, begin, end);
- if (begin == std::string::npos || end == std::string::npos) {
- return prompt;
- }
- auto pre = prompt.substr(0, begin);
- auto post = prompt.substr(end+1);
- return pre + replacement + post;
-}
struct llava_context * llava_init(gpt_params * params) {
@@ -102,46 +72,6 @@ struct llava_context * llava_init(gpt_params * params) {
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
- // load and preprocess the image
- clip_image_u8 img;
-
- if (prompt_contains_image(prompt)) {
- if (img_path) {
- printf("using base64 encoded image instead of command line image path\n");
- }
- if (!get_image_from_prompt(prompt, &img)) {
- fprintf(stderr, "%s: can't load image from prompt\n", __func__);
- clip_free(ctx_clip);
- return NULL;
- }
- prompt = remove_image_from_prompt(prompt);
- } else {
- if (!clip_image_load_from_file(img_path, &img)) {
- fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
- clip_free(ctx_clip);
- return NULL;
- }
- }
-
- float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
- if (!image_embd) {
- fprintf(stderr, "Unable to allocate memory for image embeddings\n");
- return NULL;
- }
-
- int n_img_embd;
- int n_img_pos;
- float t_img_enc_ms;
- if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
- fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
- clip_free(ctx_clip);
- return NULL;
- }
-
- // we get the embeddings, free up the memory required for CLIP
- clip_free(ctx_clip);
- ctx_clip = NULL;
-
llama_backend_init(params->numa);
llama_model_params model_params = llama_model_default_params();
@@ -194,6 +124,11 @@ struct llava_context * llava_init(gpt_params * params) {
}
void llava_free(struct llava_context * ctx_llava) {
+ if (ctx_llava->ctx_clip) {
+ clip_free(ctx_clip);
+ ctx_llava->ctx_clip = NULL;
+ }
+
llama_free(ctx_llava->ctx_llama);
llama_free_model(ctx_llava->model);
llama_backend_free();
@@ -249,6 +184,27 @@ int main(int argc, char ** argv) {
return 1;
}
+ // load and preprocess the image
+ clip_image_u8 img;
+ if (prompt_contains_image(prompt)) {
+ if (img_path) {
+ printf("using base64 encoded image instead of command line image path\n");
+ }
+ if (!get_image_from_prompt(prompt, &img)) {
+ fprintf(stderr, "%s: can't load image from prompt\n", __func__);
+ clip_free(ctx_clip);
+ return NULL;
+ }
+ prompt = remove_image_from_prompt(prompt);
+ } else {
+ if (!clip_image_load_from_file(img_path, &img)) {
+ fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
+ clip_free(ctx_clip);
+ return NULL;
+ }
+ }
+ llava_build_img_embed(ctx_llava, &img);
+
// process the prompt
// llava chat format is "USER: \n\nASSISTANT:"
llava_process_prompt(ctx_llava, ¶ms, params.prompt.c_str());