From 07f120eb4e5d63ff9a924fa790aedcf14418fe95 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 8 Mar 2024 14:04:11 +0100 Subject: [PATCH] use set_pre_routing_handler for validate_api_key --- examples/server/server.cpp | 58 +++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 83b6c08e5..f7e087b90 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2154,7 +2154,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - sparams.api_keys.emplace_back(argv[i]); + sparams.api_keys.push_back(argv[i]); } else if (arg == "--api-key-file") { if (++i >= argc) { invalid_param = true; @@ -2832,7 +2832,8 @@ int main(int argc, char ** argv) { log_data["port"] = std::to_string(sparams.port); if (sparams.api_keys.size() == 1) { - log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); + auto key = sparams.api_keys[0]; + log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); } else if (sparams.api_keys.size() > 1) { log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } @@ -2858,12 +2859,32 @@ int main(int argc, char ** argv) { } // Middleware for API key validation - auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { + auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) { + // TODO: should we apply API key to all endpoints, including "/health" and "/models"? + static const std::set protected_endpoints = { + "/completion", + "/completions", + "/v1/completions", + "/chat/completions", + "/v1/chat/completions", + "/infill", + "/tokenize", + "/detokenize", + "/embedding", + "/embeddings", + "/v1/embeddings", + }; + // If API key is not set, skip validation if (sparams.api_keys.empty()) { return true; } + // If path is not in protected_endpoints list, skip validation + if (protected_endpoints.find(req.path) == protected_endpoints.end()) { + return true; + } + // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); @@ -2884,7 +2905,16 @@ int main(int argc, char ** argv) { return false; }; + // register server middlewares + svr.set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + if (sparams.public_path.empty()) { + // using embedded static files auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(content), len, mime_type); @@ -2913,11 +2943,8 @@ int main(int argc, char ** argv) { res.set_content(data.dump(), "application/json; charset=utf-8"); }); - const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } json data = json::parse(req.body); @@ -2993,9 +3020,9 @@ int main(int argc, char ** argv) { } }; - svr.Post("/completion", completions); // legacy - svr.Post("/completions", completions); - svr.Post("/v1/completions", completions); + svr.Post("/completion", handle_completions); // legacy + svr.Post("/completions", handle_completions); + svr.Post("/v1/completions", handle_completions); svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); @@ -3016,12 +3043,8 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), "application/json; charset=utf-8"); }); - const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) { + const auto chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3095,11 +3118,8 @@ int main(int argc, char ** argv) { svr.Post("/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr.Post("/infill", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } json data = json::parse(req.body);