Update llava-cli.cpp to support comma-delimited image lists

Add the ability to specify a comma-delimited list of images on the command line, so multiple images can be batch-processed without reloading the model file.
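
Example usage (illustrative model paths; `-m`, `--mmproj`, `--image`, and `-p` are the existing llava-cli options):

    ./llava-cli -m models/llava/ggml-model-q5_k.gguf --mmproj models/llava/mmproj-model-f16.gguf --image a.jpg,b.jpg,c.jpg -p "describe the image in detail."

The model and CLIP projector are loaded once; each image in the list is then embedded and processed in its own llava context.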
commit 41d7c5eaca
parent b06c16ef9f
Author: cpumaxx
Date:   2024-03-25 16:44:38 -07:00 (committed by GitHub)


@@ -208,26 +208,28 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 }
 
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);
 
     llama_model_params model_params = llama_model_params_from_gpt_params(*params);
 
     llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
 
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -273,23 +275,42 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        fprintf(stderr, "%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }
 
-    auto image_embed = load_image(ctx_llava, &params);
-    if (!image_embed) {
-        return 1;
+    std::stringstream ss(params.image);
+    std::vector<std::string> imagestack;
+    while( ss.good() )
+    {
+        std::string substr;
+        getline( ss, substr, ',' );
+        imagestack.push_back( substr );
     }
 
-    // process the prompt
-    process_prompt(ctx_llava, image_embed, &params, params.prompt);
+    for (auto & image : imagestack) {
+        auto ctx_llava = llava_init_context(&params, model);
 
-    llama_print_timings(ctx_llava->ctx_llama);
+        params.image=image;
+        auto image_embed = load_image(ctx_llava, &params);
+        if (!image_embed) {
+            std::cerr << "error: failed to load image " << params.image << ". Terminating\n\n";
+            return 1;
+        }
 
-    llava_image_embed_free(image_embed);
-    llava_free(ctx_llava);
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);
+
+        llama_print_timings(ctx_llava->ctx_llama);
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+    }
+    llama_free_model(model);
+
     return 0;
 }
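
For reference, the list-splitting logic introduced above can be exercised on its own. A minimal standalone sketch (a hypothetical test harness, not part of the commit) of the same std::stringstream/getline split:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Split a comma-delimited argument into individual image paths,
    // mirroring the loop added to main() above.
    static std::vector<std::string> split_image_list(const std::string & arg) {
        std::vector<std::string> images;
        std::stringstream ss(arg);
        std::string substr;
        while (std::getline(ss, substr, ',')) {
            images.push_back(substr);
        }
        return images;
    }

    int main() {
        for (const auto & image : split_image_list("a.jpg,b.jpg,c.jpg")) {
            std::cout << image << "\n"; // prints a.jpg, b.jpg, c.jpg on separate lines
        }
        return 0;
    }

One behavioral note on the commit's variant: looping on `ss.good()` and then calling getline pushes a trailing empty string when the argument ends with a comma (e.g. `--image a.jpg,`), which would then fail in load_image; testing `std::getline` directly, as in the sketch, avoids that edge case.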