diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index cf19a2f78..ffdad9c99 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -128,7 +128,44 @@ void llava_free(struct llava_context * ctx_llava) {
     llama_backend_free();
 }
 
-static void llava_process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
+
+
+// Load the image for this run and build its embedding.
+//
+// If params->prompt contains an image (see prompt_contains_image), that image
+// is used, even when params->image is also set, and it is then removed from
+// params->prompt so the caller's later use of the prompt sees the cleaned
+// text. Otherwise the image is read from the file named by params->image.
+//
+// On success the outputs of llava_build_img_embed are stored through
+// image_embd and n_image_pos and true is returned. On failure a message is
+// printed to stderr, the outputs are left untouched and false is returned.
+static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_image_pos) {
+    // load and preprocess the image
+    clip_image_u8 img;
+    auto prompt = params->prompt;
+    if (prompt_contains_image(prompt)) {
+        if (!params->image.empty()) {
+            printf("using base64 encoded image instead of command line image path\n");
+        }
+        if (!get_image_from_prompt(prompt, &img)) {
+            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
+            return false;
+        }
+        // store into params: assigning to the local copy would be lost on return
+        params->prompt = remove_image_from_prompt(prompt);
+    } else {
+        if (!clip_image_load_from_file(params->image.c_str(), &img)) {
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
+            return false;
+        }
+    }
+    llava_build_img_embed(ctx_llava, params->n_threads, &img, image_embd, n_image_pos);
+
+    return true;
+}
+
+static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -156,7 +193,6 @@ static void llava_process_prompt(struct llava_context * ctx_llava, float * image
 
 }
 
-
 int main(int argc, char ** argv) {
     ggml_time_init();
 
@@ -178,32 +214,16 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // load and preprocess the image
-    clip_image_u8 img;
-    auto prompt = params.prompt;
-    if (prompt_contains_image(prompt)) {
-        if (!params.image.empty()) {
-            printf("using base64 encoded image instead of command line image path\n");
-        }
-        if (!get_image_from_prompt(prompt, &img)) {
-            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
-            llava_free(ctx_llava);
-            return 1;
-        }
-        prompt = remove_image_from_prompt(prompt);
-    } else {
-        if (!clip_image_load_from_file(params.image.c_str(), &img)) {
-            fprintf(stderr, "%s: is %s really an image file?\n", __func__, params.image.c_str());
-            llava_free(ctx_llava);
-            return 1;
-        }
-    }
     float * image_embd;
     int n_image_pos;
-    llava_build_img_embed(ctx_llava, params.n_threads, &img, &image_embd, &n_image_pos);
+    // image_embd / n_image_pos are only set when load_image succeeds
+    if (!load_image(ctx_llava, &params, &image_embd, &n_image_pos)) {
+        llava_free(ctx_llava);
+        return 1;
+    }
 
     // process the prompt
-    llava_process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str());
+    process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str());
 
     llama_print_timings(ctx_llava->ctx_llama);
 