handle env

This commit is contained in:
Xuan Son Nguyen 2024-09-05 19:26:21 +02:00
parent 753782ae35
commit 60ae92bd54
3 changed files with 90 additions and 63 deletions

View file

@ -77,41 +77,6 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
//
// Environment variable utils
//
template<typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::string(value) : target;
}
template<typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stoi(value) : target;
}
template<typename T>
static typename std::enable_if<std::is_floating_point<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stof(value) : target;
}
template<typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
if (value) {
std::string val(value);
target = val == "1" || val == "true";
}
}
// //
// CPU utils // CPU utils
// //
@ -390,6 +355,29 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
} }
} }
// handle environment variables
for (auto & opt : options) {
std::string value;
if (opt.get_value_from_env(value)) {
try {
if (opt.handler_void && (value == "1" || value == "true")) {
opt.handler_void();
}
if (opt.handler_int) {
opt.handler_int(std::stoi(value));
}
if (opt.handler_string) {
opt.handler_string(value);
continue;
}
} catch (std::exception & e) {
throw std::invalid_argument(format(
"error while handling environment variable \"%s\": %s\n\n", opt.env.c_str(), e.what()));
}
}
}
// handle command line arguments
auto check_arg = [&](int i) { auto check_arg = [&](int i) {
if (i+1 >= argc) { if (i+1 >= argc) {
throw std::invalid_argument("expected value for argument"); throw std::invalid_argument("expected value for argument");
@ -405,6 +393,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str())); throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str()));
} }
auto opt = *arg_to_options[arg]; auto opt = *arg_to_options[arg];
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env.c_str(), arg.c_str());
}
try { try {
if (opt.handler_void) { if (opt.handler_void) {
opt.handler_void(); opt.handler_void();
@ -449,10 +440,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vecto
gpt_params_handle_model_default(params); gpt_params_handle_model_default(params);
if (params.hf_token.empty()) {
get_env("HF_TOKEN", params.hf_token);
}
if (params.escape) { if (params.escape) {
string_process_escapes(params.prompt); string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix); string_process_escapes(params.input_prefix);
@ -762,7 +749,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
params.cpuparams.n_threads = std::thread::hardware_concurrency(); params.cpuparams.n_threads = std::thread::hardware_concurrency();
} }
} }
)); ).set_env("LLAMA_ARG_THREADS"));
add_opt(llama_arg( add_opt(llama_arg(
{"-tb", "--threads-batch"}, "N", {"-tb", "--threads-batch"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads)", "number of threads to use during batch and prompt processing (default: same as --threads)",
@ -960,28 +947,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](int value) { [&params](int value) {
params.n_ctx = value; params.n_ctx = value;
} }
)); ).set_env("LLAMA_ARG_CTX_SIZE"));
add_opt(llama_arg( add_opt(llama_arg(
{"-n", "--predict"}, "N", {"-n", "--predict"}, "N",
format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict), format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict),
[&params](int value) { [&params](int value) {
params.n_predict = value; params.n_predict = value;
} }
)); ).set_env("LLAMA_ARG_N_PREDICT"));
add_opt(llama_arg( add_opt(llama_arg(
{"-b", "--batch-size"}, "N", {"-b", "--batch-size"}, "N",
format("logical maximum batch size (default: %d)", params.n_batch), format("logical maximum batch size (default: %d)", params.n_batch),
[&params](int value) { [&params](int value) {
params.n_batch = value; params.n_batch = value;
} }
)); ).set_env("LLAMA_ARG_BATCH"));
add_opt(llama_arg( add_opt(llama_arg(
{"-ub", "--ubatch-size"}, "N", {"-ub", "--ubatch-size"}, "N",
format("physical maximum batch size (default: %d)", params.n_ubatch), format("physical maximum batch size (default: %d)", params.n_ubatch),
[&params](int value) { [&params](int value) {
params.n_ubatch = value; params.n_ubatch = value;
} }
)); ).set_env("LLAMA_ARG_UBATCH"));
add_opt(llama_arg( add_opt(llama_arg(
{"--keep"}, "N", {"--keep"}, "N",
format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep), format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep),
@ -1002,7 +989,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() { [&params]() {
params.flash_attn = true; params.flash_attn = true;
} }
)); ).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(llama_arg( add_opt(llama_arg(
{"-p", "--prompt"}, "PROMPT", {"-p", "--prompt"}, "PROMPT",
"prompt to start generation with\n", "prompt to start generation with\n",
@ -1599,7 +1586,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) { [&params](std::string value) {
params.defrag_thold = std::stof(value); params.defrag_thold = std::stof(value);
} }
)); ).set_env("LLAMA_ARG_DEFRAG_THOLD"));
add_opt(llama_arg( add_opt(llama_arg(
{"-np", "--parallel"}, "N", {"-np", "--parallel"}, "N",
format("number of parallel sequences to decode (default: %d)", params.n_parallel), format("number of parallel sequences to decode (default: %d)", params.n_parallel),
@ -1620,14 +1607,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() { [&params]() {
params.cont_batching = true; params.cont_batching = true;
} }
)); ).set_env("LLAMA_ARG_CONT_BATCHING"));
add_opt(llama_arg( add_opt(llama_arg(
{"-nocb", "--no-cont-batching"}, {"-nocb", "--no-cont-batching"},
"disable continuous batching", "disable continuous batching",
[&params]() { [&params]() {
params.cont_batching = false; params.cont_batching = false;
} }
)); ).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
add_opt(llama_arg( add_opt(llama_arg(
{"--mmproj"}, "FILE", {"--mmproj"}, "FILE",
"path to a multimodal projector file for LLaVA. see examples/llava/README.md", "path to a multimodal projector file for LLaVA. see examples/llava/README.md",
@ -1688,7 +1675,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
} }
} }
)); ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(llama_arg( add_opt(llama_arg(
{"-ngld", "--gpu-layers-draft"}, "N", {"-ngld", "--gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model", "number of layers to store in VRAM for the draft model",
@ -1830,7 +1817,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) { [&params](std::string value) {
params.model = value; params.model = value;
} }
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
add_opt(llama_arg( add_opt(llama_arg(
{"-md", "--model-draft"}, "FNAME", {"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)", "draft model for speculative decoding (default: unused)",
@ -1844,28 +1831,28 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) { [&params](std::string value) {
params.model_url = value; params.model_url = value;
} }
)); ).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(llama_arg( add_opt(llama_arg(
{"-hfr", "--hf-repo"}, "REPO", {"-hfr", "--hf-repo"}, "REPO",
"Hugging Face model repository (default: unused)", "Hugging Face model repository (default: unused)",
[&params](std::string value) { [&params](std::string value) {
params.hf_repo = value; params.hf_repo = value;
} }
)); ).set_env("LLAMA_ARG_HF_REPO"));
add_opt(llama_arg( add_opt(llama_arg(
{"-hff", "--hf-file"}, "FILE", {"-hff", "--hf-file"}, "FILE",
"Hugging Face model file (default: unused)", "Hugging Face model file (default: unused)",
[&params](std::string value) { [&params](std::string value) {
params.hf_file = value; params.hf_file = value;
} }
)); ).set_env("LLAMA_ARG_HF_FILE"));
add_opt(llama_arg( add_opt(llama_arg(
{"-hft", "--hf-token"}, "TOKEN", {"-hft", "--hf-token"}, "TOKEN",
"Hugging Face access token (default: value from HF_TOKEN environment variable)", "Hugging Face access token (default: value from HF_TOKEN environment variable)",
[&params](std::string value) { [&params](std::string value) {
params.hf_token = value; params.hf_token = value;
} }
)); ).set_env("HF_TOKEN"));
add_opt(llama_arg( add_opt(llama_arg(
{"--context-file"}, "FNAME", {"--context-file"}, "FNAME",
"file to load context from (repeat to specify multiple files)", "file to load context from (repeat to specify multiple files)",
@ -2012,14 +1999,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](std::string value) { [&params](std::string value) {
params.hostname = value; params.hostname = value;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST"));
add_opt(llama_arg( add_opt(llama_arg(
{"--port"}, "PORT", {"--port"}, "PORT",
format("port to listen (default: %d)", params.port), format("port to listen (default: %d)", params.port),
[&params](int value) { [&params](int value) {
params.port = value; params.port = value;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
add_opt(llama_arg( add_opt(llama_arg(
{"--path"}, "PATH", {"--path"}, "PATH",
format("path to serve static files from (default: %s)", params.public_path.c_str()), format("path to serve static files from (default: %s)", params.public_path.c_str()),
@ -2028,19 +2015,19 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg( add_opt(llama_arg(
{"--embedding(s)"}, {"--embedding", "--embeddings"},
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"), format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
[&params]() { [&params]() {
params.embedding = true; params.embedding = true;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(llama_arg( add_opt(llama_arg(
{"--api-key"}, "KEY", {"--api-key"}, "KEY",
"API key to use for authentication (default: none)", "API key to use for authentication (default: none)",
[&params](std::string value) { [&params](std::string value) {
params.api_keys.push_back(value); params.api_keys.push_back(value);
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
add_opt(llama_arg( add_opt(llama_arg(
{"--api-key-file"}, "FNAME", {"--api-key-file"}, "FNAME",
"path to file containing API keys (default: none)", "path to file containing API keys (default: none)",
@ -2086,7 +2073,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params](int value) { [&params](int value) {
params.n_threads_http = value; params.n_threads_http = value;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
add_opt(llama_arg( add_opt(llama_arg(
{"-spf", "--system-prompt-file"}, "FNAME", {"-spf", "--system-prompt-file"}, "FNAME",
"set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications", "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications",
@ -2123,14 +2110,14 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
[&params]() { [&params]() {
params.endpoint_metrics = true; params.endpoint_metrics = true;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
add_opt(llama_arg( add_opt(llama_arg(
{"--no-slots"}, {"--no-slots"},
format("disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"), format("disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[&params]() { [&params]() {
params.endpoint_slots = false; params.endpoint_slots = false;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS"));
add_opt(llama_arg( add_opt(llama_arg(
{"--slot-save-path"}, "PATH", {"--slot-save-path"}, "PATH",
"path to save slot kv cache (default: disabled)", "path to save slot kv cache (default: disabled)",
@ -2157,7 +2144,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
} }
params.chat_template = value; params.chat_template = value;
} }
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(llama_arg( add_opt(llama_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY", {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

View file

@ -316,6 +316,7 @@ struct llama_arg {
llama_arg(std::vector<std::string> args, std::string help, std::function<void(void)> handler) : args(args), help(help), handler_void(handler) {} llama_arg(std::vector<std::string> args, std::string help, std::function<void(void)> handler) : args(args), help(help), handler_void(handler) {}
// support 2 values for arg // support 2 values for arg
// note: env variable is not yet support for 2 values
llama_arg(std::vector<std::string> args, std::string value_hint, std::string value_hint_2, std::string help, std::function<void(std::string, std::string)> handler) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} llama_arg(std::vector<std::string> args, std::string value_hint, std::string value_hint_2, std::string help, std::function<void(std::string, std::string)> handler) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
llama_arg & set_examples(std::set<enum llama_example> examples) { llama_arg & set_examples(std::set<enum llama_example> examples) {
@ -324,6 +325,7 @@ struct llama_arg {
} }
llama_arg & set_env(std::string env) { llama_arg & set_env(std::string env) {
help = help + "\n(env: " + env + ")";
this->env = std::move(env); this->env = std::move(env);
return *this; return *this;
} }
@ -332,6 +334,20 @@ struct llama_arg {
return examples.find(ex) != examples.end(); return examples.find(ex) != examples.end();
} }
bool get_value_from_env(std::string & output) {
if (env.empty()) return false;
char * value = std::getenv(env.c_str());
if (value) {
output = value;
return true;
}
return false;
}
bool has_value_from_env() {
return std::getenv(env.c_str());
}
std::string to_string(bool markdown); std::string to_string(bool markdown);
}; };

View file

@ -63,5 +63,29 @@ int main(void) {
assert(params.n_predict == 6789); assert(params.n_predict == 6789);
assert(params.n_batch == 9090); assert(params.n_batch == 9090);
printf("test-arg-parser: test environment variables (valid + invalid usages)\n\n");
setenv("LLAMA_ARG_THREADS", "blah", true);
argv = {"binary_name"};
assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
setenv("LLAMA_ARG_THREADS", "1010", true);
argv = {"binary_name"};
assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
assert(params.model == "blah.gguf");
assert(params.cpuparams.n_threads == 1010);
printf("test-arg-parser: test environment variables being overwritten\n\n");
setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
setenv("LLAMA_ARG_THREADS", "1010", true);
argv = {"binary_name", "-m", "overwritten.gguf"};
assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
assert(params.model == "overwritten.gguf");
assert(params.cpuparams.n_threads == 1010);
printf("test-arg-parser: all tests OK\n\n"); printf("test-arg-parser: all tests OK\n\n");
} }