From 41d7c5eaca2a4b8c36b8d11abecb0ecc7d9623ea Mon Sep 17 00:00:00 2001
From: cpumaxx <163466046+cpumaxx@users.noreply.github.com>
Date: Mon, 25 Mar 2024 16:44:38 -0700
Subject: [PATCH] Update llava-cli.cpp to support comma-delimited image lists

Add the ability to specify a comma-delimited list of images at the
command line, for batch processing of multiple images without needing
to reload the model file.
---
 examples/llava/llava-cli.cpp | 67 +++++++++++++++++++++++-----------
 1 file changed, 44 insertions(+), 23 deletions(-)

diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index e29da6cb2..673f960ef 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -208,26 +208,28 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     printf("\n");
 }
 
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);
-
+
     llama_model_params model_params = llama_model_params_from_gpt_params(*params);
-
     llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
 
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -273,23 +275,42 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        fprintf(stderr, "%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }
 
-    auto image_embed = load_image(ctx_llava, &params);
-    if (!image_embed) {
-        return 1;
+    std::stringstream ss(params.image);
+    std::vector<std::string> imagestack;
+
+    while( ss.good() )
+    {
+        std::string substr;
+        getline( ss, substr, ',' );
+        imagestack.push_back( substr );
     }
 
-    // process the prompt
-    process_prompt(ctx_llava, image_embed, &params, params.prompt);
+    for (auto & image : imagestack) {
 
-    llama_print_timings(ctx_llava->ctx_llama);
+        auto ctx_llava = llava_init_context(&params, model);
+        params.image=image;
 
-    llava_image_embed_free(image_embed);
-    llava_free(ctx_llava);
+        auto image_embed = load_image(ctx_llava, &params);
+        if (!image_embed) {
+            std::cerr << "error: failed to load image " << params.image << ". Terminating\n\n";
+            return 1;
+        }
+
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+        llama_print_timings(ctx_llava->ctx_llama);
+
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+    }
+    llama_free_model(model);
     return 0;
 }
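
Example invocation with this patch applied (a sketch: the model and projector
paths below are placeholders, while -m, --mmproj, --image, and -p are existing
llava-cli options):

    ./llava-cli -m models/llava/ggml-model-q5_k.gguf \
        --mmproj models/llava/mmproj-model-f16.gguf \
        --image one.jpg,two.jpg,three.jpg \
        -p "describe the image in detail."

The model weights are loaded once in llava_init(); each image in the list then
gets a fresh context from llava_init_context(). Setting ctx_llava->model to
NULL before llava_free() keeps the per-image cleanup from freeing the shared
model, which is released once at the end with llama_free_model().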