diff --git a/common/common.cpp b/common/common.cpp index 9fa184725..39db42608 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1075,6 +1075,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.image.emplace_back(argv[i]); return true; } + if (arg == "--template") { + CHECK_ARG + params.templ = argv[i]; + return true; + } if (arg == "-i" || arg == "--interactive") { params.interactive = true; return true; @@ -1927,6 +1932,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "multi-modality" }); options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); + options.push_back({ "*", " --template STRING", "output template replaces [image] and [description] with generated output" }); + options.push_back({ "backend" }); options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); diff --git a/common/common.h b/common/common.h index cb5e7f6df..a14c0f448 100644 --- a/common/common.h +++ b/common/common.h @@ -203,6 +203,7 @@ struct gpt_params { // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector std::vector image; // path to image file(s) + std::string templ = ""; // output template // embedding bool embedding = false; // get only sentence embedding diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 86b39f20e..25feec5c7 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -323,10 +323,27 @@ int main(int argc, char ** argv) { std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; return 1; } - + size_t pos = 0; + std::string str = params.templ; + // format output according to template + if (!params.templ.empty()){ + while((pos = str.find("[image]")) != std::string::npos) + str = str.replace(pos, 7, image); + pos = str.find("[description]"); + if (pos != std::string::npos) + std::cout << str.substr(0, pos); + else + std::cout << params.templ; + fflush(stdout); + } // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - + // terminate output according to template + if (!params.templ.empty()){ + if (pos != std::string::npos) + std::cout << str.substr(pos + 13); + fflush(stdout); + } llama_print_timings(ctx_llava->ctx_llama); llava_image_embed_free(image_embed); ctx_llava->model = NULL;