diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cf5f30d61..6294e7bc6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -37,7 +37,7 @@ struct server_params { std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; - std::string http_cors_origin = "http://localhost:8080"; + std::vector http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"}; std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; @@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); - printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. (default: %s)\n", sparams.http_cors_origin.c_str()); + printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default."); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2141,7 +2141,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - sparams.http_cors_origin = argv[i]; + std::string cors_origins = argv[i]; + std::stringstream ss(cors_origins); + std::string cors_origin; + while (std::getline(ss, cors_origin, ',')) { + sparams.http_cors_origin.push_back(cors_origin); + } } else if (arg == "--path") { @@ -2788,8 +2793,7 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); - // Disallow CORS requests from any origin not specified in the --public-domain flag - LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}}); + // Disallow CORS requests from any origin not specified in the --http-cors-origin flag svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); @@ -2801,9 +2805,10 @@ int main(int argc, char **argv) res.status = 200; return httplib::Server::HandlerResponse::Handled; } - - if (req.has_header("Origin") && sparams.http_cors_origin != "*") { - if (req.get_header_value("Origin") != sparams.http_cors_origin) { + + // Check if the request is from any of the allowed origins + if (req.has_header("Origin") && sparams.http_cors_origin[2] != "*"){ + if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) { LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); res.status = 403; // HTTP Forbidden res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); @@ -2814,12 +2819,6 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Unhandled; }); - if (sparams.http_cors_origin == "*"){ - if (sparams.api_keys.size() == 0) { - LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {}); - } - } - svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { server_state current_state = state.load(); switch(current_state) { @@ -3502,6 +3501,18 @@ int main(int argc, char **argv) svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); + + std::string cors_enabled_domains = ""; + for (const auto &domain : sparams.http_cors_origin) { + cors_enabled_domains += domain + " "; + if (domain == "*") { + if (sparams.api_keys.size() == 0) { + LOG_WARNING("CORS requests are enabled for any request without an API key set. This is not recommended.", {}); + } + } + } + LOG_INFO("CORS enabled domains", {{"domains", cors_enabled_domains}}); + // run the HTTP server in a thread - see comment below std::thread t([&]() {