Allow setting multiple CORS enabled origins.

* Allow setting multiple CORS enabled origins.
* Add both "http://localhost:8080" and "http://127.0.0.1:8080" by default.
* Move CORS logging below server startup to make it more visible.
This commit is contained in:
StrangebytesDev 2024-03-03 15:29:32 -08:00
parent 2dc03f404d
commit f4ca2301bb

View file

@ -37,7 +37,7 @@ struct server_params {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys; std::vector<std::string> api_keys;
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
std::string http_cors_origin = "http://localhost:8080"; std::vector<std::string> http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"};
std::string chat_template = ""; std::string chat_template = "";
int32_t port = 8080; int32_t port = 8080;
int32_t read_timeout = 600; int32_t read_timeout = 600;
@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. (default: %s)\n", sparams.http_cors_origin.c_str()); printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default.");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@ -2141,7 +2141,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.http_cors_origin = argv[i]; std::string cors_origins = argv[i];
std::stringstream ss(cors_origins);
std::string cors_origin;
while (std::getline(ss, cors_origin, ',')) {
sparams.http_cors_origin.push_back(cors_origin);
}
} }
else if (arg == "--path") else if (arg == "--path")
{ {
@ -2788,8 +2793,7 @@ int main(int argc, char **argv)
svr.set_default_headers({{"Server", "llama.cpp"}}); svr.set_default_headers({{"Server", "llama.cpp"}});
// Disallow CORS requests from any origin not specified in the --public-domain flag // Disallow CORS requests from any origin not specified in the --http-cors-origin flag
LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}});
svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Credentials", "true");
@ -2802,8 +2806,9 @@ int main(int argc, char **argv)
return httplib::Server::HandlerResponse::Handled; return httplib::Server::HandlerResponse::Handled;
} }
if (req.has_header("Origin") && sparams.http_cors_origin != "*") { // Check if the request is from any of the allowed origins
if (req.get_header_value("Origin") != sparams.http_cors_origin) { if (req.has_header("Origin") && sparams.http_cors_origin[2] != "*"){
if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) {
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
res.status = 403; // HTTP Forbidden res.status = 403; // HTTP Forbidden
res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
@ -2814,12 +2819,6 @@ int main(int argc, char **argv)
return httplib::Server::HandlerResponse::Unhandled; return httplib::Server::HandlerResponse::Unhandled;
}); });
if (sparams.http_cors_origin == "*"){
if (sparams.api_keys.size() == 0) {
LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {});
}
}
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
server_state current_state = state.load(); server_state current_state = state.load();
switch(current_state) { switch(current_state) {
@ -3502,6 +3501,18 @@ int main(int argc, char **argv)
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
LOG_INFO("HTTP server listening", log_data); LOG_INFO("HTTP server listening", log_data);
std::string cors_enabled_domains = "";
for (const auto &domain : sparams.http_cors_origin) {
cors_enabled_domains += domain + " ";
if (domain == "*") {
if (sparams.api_keys.size() == 0) {
LOG_WARNING("CORS requests are enabled for any request without an API key set. This is not recommended.", {});
}
}
}
LOG_INFO("CORS enabled domains", {{"domains", cors_enabled_domains}});
// run the HTTP server in a thread - see comment below // run the HTTP server in a thread - see comment below
std::thread t([&]() std::thread t([&]()
{ {