Allow setting multiple CORS enabled origins.
* Allow setting multiple CORS enabled origins. * Add both "http://localhost:8080" and "http://127.0.0.1:8080" by default. * Move CORS logging below server startup to make it more visible.
This commit is contained in:
parent
2dc03f404d
commit
f4ca2301bb
1 changed files with 25 additions and 14 deletions
|
@ -37,7 +37,7 @@ struct server_params {
|
||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
std::string http_cors_origin = "http://localhost:8080";
|
std::vector<std::string> http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"};
|
||||||
std::string chat_template = "";
|
std::string chat_template = "";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
|
@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||||
printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. (default: %s)\n", sparams.http_cors_origin.c_str());
|
printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default.");
|
||||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
|
@ -2141,7 +2141,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.http_cors_origin = argv[i];
|
std::string cors_origins = argv[i];
|
||||||
|
std::stringstream ss(cors_origins);
|
||||||
|
std::string cors_origin;
|
||||||
|
while (std::getline(ss, cors_origin, ',')) {
|
||||||
|
sparams.http_cors_origin.push_back(cors_origin);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (arg == "--path")
|
else if (arg == "--path")
|
||||||
{
|
{
|
||||||
|
@ -2788,8 +2793,7 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
svr.set_default_headers({{"Server", "llama.cpp"}});
|
svr.set_default_headers({{"Server", "llama.cpp"}});
|
||||||
|
|
||||||
// Disallow CORS requests from any origin not specified in the --public-domain flag
|
// Disallow CORS requests from any origin not specified in the --http-cors-origin flag
|
||||||
LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}});
|
|
||||||
svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
|
svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||||
|
@ -2801,9 +2805,10 @@ int main(int argc, char **argv)
|
||||||
res.status = 200;
|
res.status = 200;
|
||||||
return httplib::Server::HandlerResponse::Handled;
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req.has_header("Origin") && sparams.http_cors_origin != "*") {
|
// Check if the request is from any of the allowed origins
|
||||||
if (req.get_header_value("Origin") != sparams.http_cors_origin) {
|
if (req.has_header("Origin") && sparams.http_cors_origin[2] != "*"){
|
||||||
|
if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) {
|
||||||
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
|
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
|
||||||
res.status = 403; // HTTP Forbidden
|
res.status = 403; // HTTP Forbidden
|
||||||
res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
|
res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
|
||||||
|
@ -2814,12 +2819,6 @@ int main(int argc, char **argv)
|
||||||
return httplib::Server::HandlerResponse::Unhandled;
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (sparams.http_cors_origin == "*"){
|
|
||||||
if (sparams.api_keys.size() == 0) {
|
|
||||||
LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
||||||
server_state current_state = state.load();
|
server_state current_state = state.load();
|
||||||
switch(current_state) {
|
switch(current_state) {
|
||||||
|
@ -3502,6 +3501,18 @@ int main(int argc, char **argv)
|
||||||
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
|
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
|
||||||
|
|
||||||
LOG_INFO("HTTP server listening", log_data);
|
LOG_INFO("HTTP server listening", log_data);
|
||||||
|
|
||||||
|
std::string cors_enabled_domains = "";
|
||||||
|
for (const auto &domain : sparams.http_cors_origin) {
|
||||||
|
cors_enabled_domains += domain + " ";
|
||||||
|
if (domain == "*") {
|
||||||
|
if (sparams.api_keys.size() == 0) {
|
||||||
|
LOG_WARNING("CORS requests are enabled for any request without an API key set. This is not recommended.", {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_INFO("CORS enabled domains", {{"domains", cors_enabled_domains}});
|
||||||
|
|
||||||
// run the HTTP server in a thread - see comment below
|
// run the HTTP server in a thread - see comment below
|
||||||
std::thread t([&]()
|
std::thread t([&]()
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue