diff --git a/common/common.cpp b/common/common.cpp
index 1d05bc7ca..59e829660 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -89,12 +89,19 @@ get_env(std::string name, T & target) {
 }
 
 template<typename T>
-static typename std::enable_if<!std::is_same<T, bool>::value &&std::is_integral<T>::value, void>::type
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
 get_env(std::string name, T & target) {
     char * value = std::getenv(name.c_str());
     target = value ? std::stoi(value) : target;
 }
 
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
 template<typename T>
 static typename std::enable_if<std::is_same<T, bool>::value, void>::type
 get_env(std::string name, T & target) {
@@ -332,6 +339,8 @@ void gpt_params_parse_from_env(gpt_params & params) {
     get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
     get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
     get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
 }
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
diff --git a/examples/server/README.md b/examples/server/README.md
index ab0260a8e..a96af79b0 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -262,6 +262,8 @@ Available environment variables (if specified, these variables will override par
 - `LLAMA_ARG_ENDPOINT_METRICS`
 - `LLAMA_ARG_ENDPOINT_SLOTS`
 - `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
 
 ## Build
 