From e8f1bd8b39a047aa7921c17e59d842f188ceb4b8 Mon Sep 17 00:00:00 2001
From: sasha0552
Date: Wed, 24 Jul 2024 15:46:04 +0000
Subject: [PATCH] common : support for lifecycle scripts

---
 common/common.cpp          | 28 ++++++++++++++++++++++++++++
 common/common.h            | 11 +++++++++++
 examples/main/main.cpp     | 12 ++++++++++++
 examples/server/server.cpp | 12 ++++++++++++
 4 files changed, 63 insertions(+)

diff --git a/common/common.cpp b/common/common.cpp
index ec44a0552..a8d350be1 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -331,6 +331,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         }
         return true;
     }
+    if (arg == "--on-start") {
+        CHECK_ARG
+        params.on_start = argv[i];
+        return true;
+    }
+    if (arg == "--on-inference-start") {
+        CHECK_ARG
+        params.on_inference_start = argv[i];
+        return true;
+    }
+    if (arg == "--on-inference-end") {
+        CHECK_ARG
+        params.on_inference_end = argv[i];
+        return true;
+    }
     if (arg == "-p" || arg == "--prompt") {
         CHECK_ARG
         params.prompt = argv[i];
@@ -1403,6 +1418,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "-s,    --seed SEED",            "RNG seed (default: %d, use random seed for < 0)", params.seed });
     options.push_back({ "*",           "-t,    --threads N",            "number of threads to use during generation (default: %d)", params.n_threads });
     options.push_back({ "*",           "-tb,   --threads-batch N",      "number of threads to use during batch and prompt processing (default: same as --threads)" });
+    options.push_back({ "*",           "       --on-start SCRIPT",      "call the specified script at application startup" });
+    options.push_back({ "*",           "       --on-inference-start SCRIPT",
+                                                                        "call the specified script before starting the inference" });
+    options.push_back({ "*",           "       --on-inference-end SCRIPT",
+                                                                        "call the specified script when the inference is complete" });
     options.push_back({ "speculative", "-td,   --threads-draft N",      "number of threads to use during generation (default: same as --threads)" });
     options.push_back({ "speculative", "-tbd,  --threads-batch-draft N",
                                                                         "number of threads to use during batch and prompt processing (default: same as --threads-draft)" });
@@ -3223,3 +3243,11 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
 }
+
+void script_execute(const std::string & script) {
+    int result = std::system(script.c_str());
+
+    if (result != 0) {
+        fprintf(stderr, "%s: error: unable to execute script '%s'. exit code: %d\n", __func__, script.c_str(), result);
+    }
+}
diff --git a/common/common.h b/common/common.h
index 8240ff99b..e77d4feb2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -61,6 +61,11 @@ enum dimre_method {
 struct gpt_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 
+    // lifecycle scripts
+    std::string on_start           = ""; // script that will be called on application start
+    std::string on_inference_start = ""; // script that will be called when inference starts
+    std::string on_inference_end   = ""; // script that will be called when inference ends
+
     int32_t n_threads             = cpu_get_num_math();
     int32_t n_threads_draft       = -1;
     int32_t n_threads_batch       = -1; // number of threads to use for batch processing (-1 = use n_threads)
@@ -455,3 +460,9 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 void yaml_dump_non_result_info(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<llama_token> & prompt_tokens, const char * model_desc);
+
+//
+// Script utils
+//
+
+void script_execute(const std::string & script);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 61e960ea2..9edbd76bb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -137,6 +137,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.on_start.empty()) {
+        script_execute(params.on_start);
+    }
+
     llama_sampling_params & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
@@ -534,6 +538,10 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
+    if (!params.on_inference_start.empty()) {
+        script_execute(params.on_inference_start);
+    }
+
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
         llama_token * enc_input_buf = embd_inp.data();
@@ -971,6 +979,10 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.on_inference_end.empty()) {
+        script_execute(params.on_inference_end);
+    }
+
     if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
         LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
         llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7813a2957..92e13e3c0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1106,6 +1106,10 @@ struct server_context {
             {"id_task", slot.id_task},
         });
 
+        if (!params.on_inference_start.empty()) {
+            script_execute(params.on_inference_start);
+        }
+
         return true;
     }
 
@@ -1913,6 +1917,10 @@ struct server_context {
                 kv_cache_clear();
             }
 
+            if (!params.on_inference_end.empty()) {
+                script_execute(params.on_inference_end);
+            }
+
             return;
         }
     }
@@ -2496,6 +2504,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.on_start.empty()) {
+        script_execute(params.on_start);
+    }
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;
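
Note (not part of the patch): a minimal, self-contained C++ sketch of how the lifecycle hooks above are meant to be used. A non-empty script string is handed to script_execute(), which shells out via std::system() and reports a non-zero return value. The on_start value and the echo command below are illustrative only; in the patch the strings come from the --on-start / --on-inference-start / --on-inference-end flags.

// standalone sketch; mirrors the helper the patch adds to common.cpp
#include <cstdio>
#include <cstdlib>
#include <string>

static void script_execute(const std::string & script) {
    int result = std::system(script.c_str());

    if (result != 0) {
        // note: on POSIX, std::system() returns a wait status rather than the script's raw exit code
        fprintf(stderr, "%s: error: unable to execute script '%s'. exit code: %d\n", __func__, script.c_str(), result);
    }
}

int main() {
    // illustrative value; in the patch this is gpt_params::on_start, filled from --on-start
    std::string on_start = "echo lifecycle: start";

    if (!on_start.empty()) {
        script_execute(on_start);
    }

    return 0;
}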