Merge 17dfcde615
into e25fb4b18f
This commit is contained in:
commit
15ecc09971
2 changed files with 60 additions and 36 deletions
|
@ -43,8 +43,9 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
|
||||||
- `--host`: Set the hostname or IP address to listen on. Default: `127.0.0.1`.
|
- `--host`: Set the hostname or IP address to listen on. Default: `127.0.0.1`.
|
||||||
- `--port`: Set the port to listen. Default: `8080`.
|
- `--port`: Set the port to listen. Default: `8080`.
|
||||||
- `--path`: path from which to serve static files (default examples/server/public)
|
- `--path`: path from which to serve static files (default examples/server/public)
|
||||||
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
|
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests to `/completion`, `/infill` and `/chat/completions` must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
|
||||||
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
|
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
|
||||||
|
- `--admin-key`: Set an admin key for request authorization. With an admin key set, requests to `/metrics` and `/slots` must have the Authorization header set with the admin key as Bearer token. Additionally, `/health` will not show slots without the key. May be used multiple times to enable multiple valid keys.
|
||||||
- `--embedding`: Enable embedding extraction, Default: disabled.
|
- `--embedding`: Enable embedding extraction, Default: disabled.
|
||||||
- `-np N`, `--parallel N`: Set the number of slots for processing requests (default: 1)
|
- `-np N`, `--parallel N`: Set the number of slots for processing requests (default: 1)
|
||||||
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
|
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
|
||||||
|
|
|
@ -36,6 +36,7 @@ using json = nlohmann::json;
|
||||||
struct server_params {
|
struct server_params {
|
||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
|
std::vector<std::string> admin_keys;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
std::string chat_template = "";
|
std::string chat_template = "";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
|
@ -2093,6 +2094,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||||
|
printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n");
|
||||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
|
@ -2161,6 +2163,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
}
|
}
|
||||||
sparams.public_path = argv[i];
|
sparams.public_path = argv[i];
|
||||||
}
|
}
|
||||||
|
else if (arg == "--admin-key")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.admin_keys.emplace_back(argv[i]);
|
||||||
|
}
|
||||||
else if (arg == "--api-key")
|
else if (arg == "--api-key")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
|
@ -2816,6 +2827,38 @@ int main(int argc, char **argv)
|
||||||
res.set_header("Access-Control-Allow-Headers", "*");
|
res.set_header("Access-Control-Allow-Headers", "*");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Middleware for API key validation
|
||||||
|
auto validate_key = [&sparams](const httplib::Request &req, httplib::Response &res, std::vector<std::string> &keys) -> bool {
|
||||||
|
// If API key is not set, skip validation
|
||||||
|
if (keys.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API key in the header
|
||||||
|
auto auth_header = req.get_header_value("Authorization");
|
||||||
|
std::string prefix = "Bearer ";
|
||||||
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||||
|
std::string received_api_key = auth_header.substr(prefix.size());
|
||||||
|
if (std::find(keys.begin(), keys.end(), received_api_key) != keys.end()) {
|
||||||
|
return true; // API key is valid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API key in the params
|
||||||
|
auto auth_param = req.get_param_value("key");
|
||||||
|
if (std::find(keys.begin(), keys.end(), auth_param) != keys.end()) {
|
||||||
|
return true; // API key is valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key is invalid or not provided
|
||||||
|
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
|
||||||
|
res.status = 401; // Unauthorized
|
||||||
|
|
||||||
|
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
||||||
server_state current_state = state.load();
|
server_state current_state = state.load();
|
||||||
switch(current_state) {
|
switch(current_state) {
|
||||||
|
@ -2841,7 +2884,7 @@ int main(int argc, char **argv)
|
||||||
{"slots_idle", n_idle_slots},
|
{"slots_idle", n_idle_slots},
|
||||||
{"slots_processing", n_processing_slots}};
|
{"slots_processing", n_processing_slots}};
|
||||||
res.status = 200; // HTTP OK
|
res.status = 200; // HTTP OK
|
||||||
if (sparams.slots_endpoint && req.has_param("include_slots")) {
|
if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) {
|
||||||
health["slots"] = result.result_json["slots"];
|
health["slots"] = result.result_json["slots"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2866,7 +2909,10 @@ int main(int argc, char **argv)
|
||||||
});
|
});
|
||||||
|
|
||||||
if (sparams.slots_endpoint) {
|
if (sparams.slots_endpoint) {
|
||||||
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) {
|
||||||
|
if (!validate_key(req, res, sparams.admin_keys)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// request slots data using task queue
|
// request slots data using task queue
|
||||||
task_server task;
|
task_server task;
|
||||||
task.id = llama.queue_tasks.get_new_id();
|
task.id = llama.queue_tasks.get_new_id();
|
||||||
|
@ -2886,7 +2932,10 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sparams.metrics_endpoint) {
|
if (sparams.metrics_endpoint) {
|
||||||
svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) {
|
||||||
|
if (!validate_key(req, res, sparams.admin_keys)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// request slots data using task queue
|
// request slots data using task queue
|
||||||
task_server task;
|
task_server task;
|
||||||
task.id = llama.queue_tasks.get_new_id();
|
task.id = llama.queue_tasks.get_new_id();
|
||||||
|
@ -3046,32 +3095,6 @@ int main(int argc, char **argv)
|
||||||
llama.validate_model_chat_template(sparams);
|
llama.validate_model_chat_template(sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware for API key validation
|
|
||||||
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
|
||||||
// If API key is not set, skip validation
|
|
||||||
if (sparams.api_keys.empty()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for API key in the header
|
|
||||||
auto auth_header = req.get_header_value("Authorization");
|
|
||||||
std::string prefix = "Bearer ";
|
|
||||||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
|
||||||
std::string received_api_key = auth_header.substr(prefix.size());
|
|
||||||
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
|
|
||||||
return true; // API key is valid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// API key is invalid or not provided
|
|
||||||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
|
|
||||||
res.status = 401; // Unauthorized
|
|
||||||
|
|
||||||
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
|
||||||
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
// this is only called if no index.html is found in the public --path
|
// this is only called if no index.html is found in the public --path
|
||||||
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
|
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
@ -3112,10 +3135,10 @@ int main(int argc, char **argv)
|
||||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!validate_api_key(req, res)) {
|
if (!validate_key(req, res, sparams.api_keys)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
|
@ -3210,10 +3233,10 @@ int main(int argc, char **argv)
|
||||||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!validate_api_key(req, res)) {
|
if (!validate_key(req, res, sparams.api_keys)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
|
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
|
||||||
|
@ -3293,10 +3316,10 @@ int main(int argc, char **argv)
|
||||||
svr.Post("/chat/completions", chat_completions);
|
svr.Post("/chat/completions", chat_completions);
|
||||||
svr.Post("/v1/chat/completions", chat_completions);
|
svr.Post("/v1/chat/completions", chat_completions);
|
||||||
|
|
||||||
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!validate_api_key(req, res)) {
|
if (!validate_key(req, res, sparams.api_keys)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue