Add Admin key param and generalize key check

This commit is contained in:
Robey Holderith 2024-03-02 12:02:21 -08:00
parent 4a6e2d6142
commit ebc1decb10

View file

@ -36,6 +36,7 @@ using json = nlohmann::json;
struct server_params { struct server_params {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys; std::vector<std::string> api_keys;
std::vector<std::string> admin_keys;
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
std::string chat_template = ""; std::string chat_template = "";
int32_t port = 8080; int32_t port = 8080;
@ -2060,6 +2061,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@ -2128,6 +2130,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
sparams.public_path = argv[i]; sparams.public_path = argv[i];
} }
else if (arg == "--admin-key")
{
// The admin key is the argument following the flag; if the command line
// ends here, flag the parameter as invalid and stop.
if (++i >= argc)
{
invalid_param = true;
break;
}
// Append to the admin key list (each time this branch runs, one more
// key is added to sparams.admin_keys).
sparams.admin_keys.emplace_back(argv[i]);
}
else if (arg == "--api-key") else if (arg == "--api-key")
{ {
if (++i >= argc) if (++i >= argc)
@ -2772,6 +2783,38 @@ int main(int argc, char **argv)
res.set_header("Access-Control-Allow-Headers", "*"); res.set_header("Access-Control-Allow-Headers", "*");
}); });
// Middleware for key validation, shared by the API-key and admin-key checks.
//
// req  - incoming request; a key is looked for first in the
//        "Authorization: Bearer <key>" header, then in the "key" request parameter.
// res  - on rejection, receives a 401 status and a plain-text error body.
// keys - list of accepted keys; an empty list disables validation.
//
// Returns true when the request may proceed, false when it has been rejected.
auto validate_key = [](const httplib::Request &req, httplib::Response &res, const std::vector<std::string> &keys) -> bool {
    // If no key is set, skip validation
    if (keys.empty()) {
        return true;
    }

    const auto is_known_key = [&keys](const std::string &candidate) {
        return std::find(keys.begin(), keys.end(), candidate) != keys.end();
    };

    // Check for a key in the header
    const auto auth_header = req.get_header_value("Authorization");
    const std::string prefix = "Bearer ";
    if (auth_header.compare(0, prefix.size(), prefix) == 0 &&
        is_known_key(auth_header.substr(prefix.size()))) {
        return true; // key is valid
    }

    // Check for a key in the params. Only a parameter that is actually present
    // is compared, so a request without "key" is never treated as having
    // supplied a credential.
    if (req.has_param("key") && is_known_key(req.get_param_value("key"))) {
        return true; // key is valid
    }

    // Key is invalid or not provided
    res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
    res.status = 401; // Unauthorized
    LOG_WARNING("Unauthorized: Invalid API Key", {});
    return false;
};
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
server_state current_state = state.load(); server_state current_state = state.load();
switch(current_state) { switch(current_state) {
@ -2797,7 +2840,7 @@ int main(int argc, char **argv)
{"slots_idle", n_idle_slots}, {"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}}; {"slots_processing", n_processing_slots}};
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) { if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) {
health["slots"] = result.result_json["slots"]; health["slots"] = result.result_json["slots"];
} }
@ -2822,7 +2865,10 @@ int main(int argc, char **argv)
}); });
if (sparams.slots_endpoint) { if (sparams.slots_endpoint) {
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) {
if (!validate_key(req, res, sparams.admin_keys)) {
return;
}
// request slots data using task queue // request slots data using task queue
task_server task; task_server task;
task.id = llama.queue_tasks.get_new_id(); task.id = llama.queue_tasks.get_new_id();
@ -2842,7 +2888,10 @@ int main(int argc, char **argv)
} }
if (sparams.metrics_endpoint) { if (sparams.metrics_endpoint) {
svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) {
if (!validate_key(req, res, sparams.admin_keys)) {
return;
}
// request slots data using task queue // request slots data using task queue
task_server task; task_server task;
task.id = llama.queue_tasks.get_new_id(); task.id = llama.queue_tasks.get_new_id();
@ -3000,32 +3049,6 @@ int main(int argc, char **argv)
llama.validate_model_chat_template(sparams); llama.validate_model_chat_template(sparams);
} }
// Request gate for API key validation: lets the request through when no API
// keys are configured or when the Authorization header carries a known
// "Bearer " key; otherwise fills in a 401 response and reports failure.
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
    const auto &allowed = sparams.api_keys;

    // No keys configured: validation is disabled
    if (allowed.empty()) {
        return true;
    }

    // Look for a bearer token in the Authorization header
    const std::string bearer = "Bearer ";
    const auto authorization = req.get_header_value("Authorization");
    const bool has_bearer = authorization.compare(0, bearer.size(), bearer) == 0;
    if (has_bearer) {
        const std::string candidate = authorization.substr(bearer.size());
        const bool recognised = std::find(allowed.begin(), allowed.end(), candidate) != allowed.end();
        if (recognised) {
            return true; // token matches a configured key
        }
    }

    // Missing or unknown key: reject the request
    res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
    res.status = 401; // Unauthorized
    LOG_WARNING("Unauthorized: Invalid API Key", {});
    return false;
};
// this is only called if no index.html is found in the public --path // this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res) svr.Get("/", [](const httplib::Request &, httplib::Response &res)
{ {
@ -3066,10 +3089,10 @@ int main(int argc, char **argv)
res.set_content(data.dump(), "application/json; charset=utf-8"); res.set_content(data.dump(), "application/json; charset=utf-8");
}); });
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_key(req, res, sparams.api_keys)) {
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);
@ -3163,10 +3186,10 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8"); res.set_content(models.dump(), "application/json; charset=utf-8");
}); });
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_key(req, res, sparams.api_keys)) {
return; return;
} }
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
@ -3246,10 +3269,10 @@ int main(int argc, char **argv)
svr.Post("/chat/completions", chat_completions); svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions);
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_key(req, res, sparams.api_keys)) {
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);