diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2b2f4a0f4..13fa0a565 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -36,6 +36,7 @@ using json = nlohmann::json; struct server_params { std::string hostname = "127.0.0.1"; std::vector<std::string> api_keys; + std::vector<std::string> admin_keys; std::string public_path = "examples/server/public"; std::string chat_template = ""; int32_t port = 8080; @@ -2060,6 +2061,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); + printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. 
If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2128,6 +2130,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } sparams.public_path = argv[i]; } + else if (arg == "--admin-key") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + sparams.admin_keys.emplace_back(argv[i]); + } else if (arg == "--api-key") { if (++i >= argc) @@ -2772,6 +2783,38 @@ int main(int argc, char **argv) res.set_header("Access-Control-Allow-Headers", "*"); }); + // Middleware for API key validation + auto validate_key = [&sparams](const httplib::Request &req, httplib::Response &res, std::vector<std::string> &keys) -> bool { + // If API key is not set, skip validation + if (keys.empty()) { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(keys.begin(), keys.end(), received_api_key) != keys.end()) { + return true; // API key is valid + } + } + + // Check for API key in the params + auto auth_param = req.get_param_value("key"); + if (std::find(keys.begin(), keys.end(), auth_param) != keys.end()) { + return true; // API key is valid + } + + // API key is invalid or not provided + res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); + res.status = 401; // Unauthorized + + LOG_WARNING("Unauthorized: Invalid API Key", {}); + + return false; + }; + svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { server_state current_state = state.load(); switch(current_state) { @@ -2797,7 +2840,7 @@ int main(int argc, char **argv) {"slots_idle", n_idle_slots}, {"slots_processing", n_processing_slots}}; res.status = 200; // HTTP OK - if (sparams.slots_endpoint && 
req.has_param("include_slots")) { + if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) { health["slots"] = result.result_json["slots"]; } @@ -2822,7 +2865,10 @@ int main(int argc, char **argv) }); if (sparams.slots_endpoint) { - svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) { + if (!validate_key(req, res, sparams.admin_keys)) { + return; + } // request slots data using task queue task_server task; task.id = llama.queue_tasks.get_new_id(); @@ -2842,7 +2888,10 @@ int main(int argc, char **argv) } if (sparams.metrics_endpoint) { - svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) { + if (!validate_key(req, res, sparams.admin_keys)) { + return; + } // request slots data using task queue task_server task; task.id = llama.queue_tasks.get_new_id(); @@ -3000,32 +3049,6 @@ int main(int argc, char **argv) llama.validate_model_chat_template(sparams); } - // Middleware for API key validation - auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { - // If API key is not set, skip validation - if (sparams.api_keys.empty()) { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { - return true; // API key is valid - } - } - - // API key is invalid or not provided - res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); - res.status = 401; // Unauthorized - - LOG_WARNING("Unauthorized: Invalid API Key", {}); - - return false; 
- }; - // this is only called if no index.html is found in the public --path svr.Get("/", [](const httplib::Request &, httplib::Response &res) { @@ -3066,10 +3089,10 @@ int main(int argc, char **argv) res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) + svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { + if (!validate_key(req, res, sparams.api_keys)) { return; } json data = json::parse(req.body); @@ -3163,10 +3186,10 @@ int main(int argc, char **argv) res.set_content(models.dump(), "application/json; charset=utf-8"); }); - const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) + const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { + if (!validate_key(req, res, sparams.api_keys)) { return; } json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); @@ -3246,10 +3269,10 @@ int main(int argc, char **argv) svr.Post("/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) + svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { + if (!validate_key(req, res, sparams.api_keys)) { return; } json data = json::parse(req.body);