use set_pre_routing_handler for validate_api_key
This commit is contained in:
parent
d8a8bd4cc6
commit
07f120eb4e
1 changed files with 39 additions and 19 deletions
|
@ -2154,7 +2154,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.api_keys.emplace_back(argv[i]);
|
||||
sparams.api_keys.push_back(argv[i]);
|
||||
} else if (arg == "--api-key-file") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -2832,7 +2832,8 @@ int main(int argc, char ** argv) {
|
|||
log_data["port"] = std::to_string(sparams.port);
|
||||
|
||||
if (sparams.api_keys.size() == 1) {
|
||||
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
|
||||
auto key = sparams.api_keys[0];
|
||||
log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
|
||||
} else if (sparams.api_keys.size() > 1) {
|
||||
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
|
||||
}
|
||||
|
@ -2858,12 +2859,32 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// Middleware for API key validation
|
||||
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
||||
auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
|
||||
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
|
||||
static const std::set<std::string> protected_endpoints = {
|
||||
"/completion",
|
||||
"/completions",
|
||||
"/v1/completions",
|
||||
"/chat/completions",
|
||||
"/v1/chat/completions",
|
||||
"/infill",
|
||||
"/tokenize",
|
||||
"/detokenize",
|
||||
"/embedding",
|
||||
"/embeddings",
|
||||
"/v1/embeddings",
|
||||
};
|
||||
|
||||
// If API key is not set, skip validation
|
||||
if (sparams.api_keys.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If path is not in protected_endpoints list, skip validation
|
||||
if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the header
|
||||
auto auth_header = req.get_header_value("Authorization");
|
||||
|
||||
|
@ -2884,7 +2905,16 @@ int main(int argc, char ** argv) {
|
|||
return false;
|
||||
};
|
||||
|
||||
// register server middlewares
|
||||
svr.set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!middleware_validate_api_key(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
|
||||
if (sparams.public_path.empty()) {
|
||||
// using embedded static files
|
||||
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
|
||||
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
|
||||
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
|
||||
|
@ -2913,11 +2943,8 @@ int main(int argc, char ** argv) {
|
|||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
|
@ -2993,9 +3020,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
svr.Post("/completion", completions); // legacy
|
||||
svr.Post("/completions", completions);
|
||||
svr.Post("/v1/completions", completions);
|
||||
svr.Post("/completion", handle_completions); // legacy
|
||||
svr.Post("/completions", handle_completions);
|
||||
svr.Post("/v1/completions", handle_completions);
|
||||
|
||||
svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
@ -3016,12 +3043,8 @@ int main(int argc, char ** argv) {
|
|||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
return;
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
@ -3095,11 +3118,8 @@ int main(int argc, char ** argv) {
|
|||
svr.Post("/chat/completions", chat_completions);
|
||||
svr.Post("/v1/chat/completions", chat_completions);
|
||||
|
||||
svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
||||
svr.Post("/infill", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!validate_api_key(req, res)) {
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue