use set_pre_routing_handler for validate_api_key

This commit is contained in:
ngxson 2024-03-08 14:04:11 +01:00
parent d8a8bd4cc6
commit 07f120eb4e

View file

@ -2154,7 +2154,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.api_keys.emplace_back(argv[i]); sparams.api_keys.push_back(argv[i]);
} else if (arg == "--api-key-file") { } else if (arg == "--api-key-file") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -2832,7 +2832,8 @@ int main(int argc, char ** argv) {
log_data["port"] = std::to_string(sparams.port); log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) { if (sparams.api_keys.size() == 1) {
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); auto key = sparams.api_keys[0];
log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
} else if (sparams.api_keys.size() > 1) { } else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
} }
@ -2858,12 +2859,32 @@ int main(int argc, char ** argv) {
} }
// Middleware for API key validation // Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
static const std::set<std::string> protected_endpoints = {
"/completion",
"/completions",
"/v1/completions",
"/chat/completions",
"/v1/chat/completions",
"/infill",
"/tokenize",
"/detokenize",
"/embedding",
"/embeddings",
"/v1/embeddings",
};
// If API key is not set, skip validation // If API key is not set, skip validation
if (sparams.api_keys.empty()) { if (sparams.api_keys.empty()) {
return true; return true;
} }
// If path is not in protected_endpoints list, skip validation
if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
return true;
}
// Check for API key in the header // Check for API key in the header
auto auth_header = req.get_header_value("Authorization"); auto auth_header = req.get_header_value("Authorization");
@ -2884,7 +2905,16 @@ int main(int argc, char ** argv) {
return false; return false;
}; };
// register server middlewares
svr.set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
if (sparams.public_path.empty()) { if (sparams.public_path.empty()) {
// using embedded static files
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type); res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
@ -2913,11 +2943,8 @@ int main(int argc, char ** argv) {
res.set_content(data.dump(), "application/json; charset=utf-8"); res.set_content(data.dump(), "application/json; charset=utf-8");
}); });
const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body); json data = json::parse(req.body);
@ -2993,9 +3020,9 @@ int main(int argc, char ** argv) {
} }
}; };
svr.Post("/completion", completions); // legacy svr.Post("/completion", handle_completions); // legacy
svr.Post("/completions", completions); svr.Post("/completions", handle_completions);
svr.Post("/v1/completions", completions); svr.Post("/v1/completions", handle_completions);
svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) { svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
@ -3016,12 +3043,8 @@ int main(int argc, char ** argv) {
res.set_content(models.dump(), "application/json; charset=utf-8"); res.set_content(models.dump(), "application/json; charset=utf-8");
}); });
const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) { const auto chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3095,11 +3118,8 @@ int main(int argc, char ** argv) {
svr.Post("/chat/completions", chat_completions); svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions);
svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { svr.Post("/infill", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body); json data = json::parse(req.body);