From 320e89c0105be96a229bb1c3da070a5ad0c2f931 Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Sun, 3 Mar 2024 13:00:33 -0800 Subject: [PATCH] Change allowed methods. Rename cors flag. * Restrict HTTP requests to GET, POST, and OPTIONS. * rename cors flag from "--public-domain" to "--http-cors-origin" --- examples/server/server.cpp | 20 +++++++++---------- .../server/tests/features/security.feature | 12 +++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 16d60b968..d7769a7c4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -34,10 +34,10 @@ using json = nlohmann::json; struct server_params { - std::string hostname = "localhost"; + std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; - std::string public_domain = "http://localhost:8080"; + std::string http_cors_origin = "http://localhost:8080"; std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; @@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); - printf(" --public-domain DOMAIN a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str()); + printf(" --http-cors-origin DOMAIN a domain which will allow cross origin requests. (default: %s)\n", sparams.http_cors_origin.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2134,14 +2134,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } sparams.hostname = argv[i]; } - else if (arg == "--public-domain") + else if (arg == "--http-cors-origin") { if (++i >= argc) { invalid_param = true; break; } - sparams.public_domain = argv[i]; + sparams.http_cors_origin = argv[i]; } else if (arg == "--path") { @@ -2789,11 +2789,11 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); // Disallow CORS requests from any origin not specified in the --public-domain flag - LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.public_domain}}); + LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}}); svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "*"); + res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); res.set_header("Access-Control-Allow-Headers", "*"); // Allow options so that request will return a specific error message rather than a generic CORS error @@ -2802,8 +2802,8 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Handled; } - if (req.has_header("Origin") && sparams.public_domain != "*") { - if (req.get_header_value("Origin") != sparams.public_domain) { + if (req.has_header("Origin") && sparams.http_cors_origin != "*") { + if (req.get_header_value("Origin") != sparams.http_cors_origin) { LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); res.status = 403; // HTTP Forbidden res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); @@ -2814,7 +2814,7 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Unhandled; }); - if (sparams.public_domain == "*"){ + if (sparams.http_cors_origin == "*"){ if (sparams.api_keys.size() == 0) { LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {}); } diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index 0049e1a25..9b5f4f101 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -43,9 +43,9 @@ Feature: Security Then CORS header is set to Examples: Headers - | origin | cors_header | cors_header_value | - | localhost | Access-Control-Allow-Origin | localhost | - | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | - | origin | Access-Control-Allow-Credentials | true | - | web.mydomain.fr | Access-Control-Allow-Methods | * | - | web.mydomain.fr | Access-Control-Allow-Headers | * | + | origin | cors_header | cors_header_value | + | localhost | Access-Control-Allow-Origin | localhost | + | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | + | origin | Access-Control-Allow-Credentials | true | + | web.mydomain.fr | Access-Control-Allow-Methods | GET, POST, OPTIONS | + | web.mydomain.fr | Access-Control-Allow-Headers | * |