Add Admin key param and generalize key check
This commit is contained in:
parent
4a6e2d6142
commit
ebc1decb10
1 changed files with 58 additions and 35 deletions
|
@ -36,6 +36,7 @@ using json = nlohmann::json;
|
|||
struct server_params {
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::vector<std::string> api_keys;
|
||||
std::vector<std::string> admin_keys;
|
||||
std::string public_path = "examples/server/public";
|
||||
std::string chat_template = "";
|
||||
int32_t port = 8080;
|
||||
|
@ -2060,6 +2061,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||
printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n");
|
||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||
|
@ -2128,6 +2130,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
sparams.public_path = argv[i];
|
||||
}
|
||||
else if (arg == "--admin-key")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.admin_keys.emplace_back(argv[i]);
|
||||
}
|
||||
else if (arg == "--api-key")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2772,6 +2783,38 @@ int main(int argc, char **argv)
|
|||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
});
|
||||
|
||||
// Middleware for API key validation
|
||||
auto validate_key = [&sparams](const httplib::Request &req, httplib::Response &res, std::vector<std::string> &keys) -> bool {
|
||||
// If API key is not set, skip validation
|
||||
if (keys.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the header
|
||||
auto auth_header = req.get_header_value("Authorization");
|
||||
std::string prefix = "Bearer ";
|
||||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||
std::string received_api_key = auth_header.substr(prefix.size());
|
||||
if (std::find(keys.begin(), keys.end(), received_api_key) != keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
}
|
||||
|
||||
// Check for API key in the params
|
||||
auto auth_param = req.get_param_value("key");
|
||||
if (std::find(keys.begin(), keys.end(), auth_param) != keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
|
||||
res.status = 401; // Unauthorized
|
||||
|
||||
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
server_state current_state = state.load();
|
||||
switch(current_state) {
|
||||
|
@ -2797,7 +2840,7 @@ int main(int argc, char **argv)
|
|||
{"slots_idle", n_idle_slots},
|
||||
{"slots_processing", n_processing_slots}};
|
||||
res.status = 200; // HTTP OK
|
||||
if (sparams.slots_endpoint && req.has_param("include_slots")) {
|
||||
if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) {
|
||||
health["slots"] = result.result_json["slots"];
|
||||
}
|
||||
|
||||
|
@ -2822,7 +2865,10 @@ int main(int argc, char **argv)
|
|||
});
|
||||
|
||||
if (sparams.slots_endpoint) {
|
||||
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) {
|
||||
svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
if (!validate_key(req, res, sparams.admin_keys)) {
|
||||
return;
|
||||
}
|
||||
// request slots data using task queue
|
||||
task_server task;
|
||||
task.id = llama.queue_tasks.get_new_id();
|
||||
|
@ -2842,7 +2888,10 @@ int main(int argc, char **argv)
|
|||
}
|
||||
|
||||
if (sparams.metrics_endpoint) {
|
||||
svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) {
|
||||
svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
if (!validate_key(req, res, sparams.admin_keys)) {
|
||||
return;
|
||||
}
|
||||
// request slots data using task queue
|
||||
task_server task;
|
||||
task.id = llama.queue_tasks.get_new_id();
|
||||
|
@ -3000,32 +3049,6 @@ int main(int argc, char **argv)
|
|||
llama.validate_model_chat_template(sparams);
|
||||
}
|
||||
|
||||
// Middleware for API key validation
|
||||
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
||||
// If API key is not set, skip validation
|
||||
if (sparams.api_keys.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the header
|
||||
auto auth_header = req.get_header_value("Authorization");
|
||||
std::string prefix = "Bearer ";
|
||||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||
std::string received_api_key = auth_header.substr(prefix.size());
|
||||
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
|
||||
res.status = 401; // Unauthorized
|
||||
|
||||
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// this is only called if no index.html is found in the public --path
|
||||
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
|
||||
{
|
||||
|
@ -3066,10 +3089,10 @@ int main(int argc, char **argv)
|
|||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||
svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
if (!validate_key(req, res, sparams.api_keys)) {
|
||||
return;
|
||||
}
|
||||
json data = json::parse(req.body);
|
||||
|
@ -3163,10 +3186,10 @@ int main(int argc, char **argv)
|
|||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||
const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
if (!validate_key(req, res, sparams.api_keys)) {
|
||||
return;
|
||||
}
|
||||
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
|
||||
|
@ -3246,10 +3269,10 @@ int main(int argc, char **argv)
|
|||
svr.Post("/chat/completions", chat_completions);
|
||||
svr.Post("/v1/chat/completions", chat_completions);
|
||||
|
||||
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||
svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
if (!validate_key(req, res, sparams.api_keys)) {
|
||||
return;
|
||||
}
|
||||
json data = json::parse(req.body);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue