From 6baa61c1e066664410e166ed75cd57c62736f46e Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Wed, 28 Feb 2024 13:36:17 -0800 Subject: [PATCH 1/6] Enable CORS requests on all routes --- examples/server/server.cpp | 15 +++------------ examples/server/tests/features/security.feature | 2 +- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 080fa9bd5..0df3cdd90 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2821,11 +2821,11 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); - // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { + // Allow CORS requests on all routes + svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Methods", "*"); res.set_header("Access-Control-Allow-Headers", "*"); }); @@ -3113,7 +3113,6 @@ int main(int argc, char **argv) svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() }, @@ -3125,7 +3124,6 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3202,7 +3200,6 @@ int main(int argc, char **argv) svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); json models = { @@ -3222,7 +3219,6 @@ int main(int argc, char **argv) const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3305,7 +3301,6 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3375,7 +3370,6 @@ int main(int argc, char **argv) svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; if (body.count("content") != 0) @@ -3388,7 +3382,6 @@ int main(int argc, char **argv) svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) @@ -3403,7 +3396,6 @@ int main(int argc, char **argv) svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; if (body.count("content") != 0) @@ -3439,7 +3431,6 @@ int main(int argc, char **argv) svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index db06d3977..3ae413f99 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -46,5 +46,5 @@ Feature: Security | localhost | Access-Control-Allow-Origin | localhost | | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | | origin | Access-Control-Allow-Credentials | true | - | web.mydomain.fr | Access-Control-Allow-Methods | POST | + | web.mydomain.fr | Access-Control-Allow-Methods | * | | web.mydomain.fr | Access-Control-Allow-Headers | * | From 1e6a2f12c6453d7b5158b37c8a789fd3934af044 Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Sat, 2 Mar 2024 17:31:57 -0800 Subject: [PATCH 2/6] Add --public-domain flag to server to enable CORS requests. * Disable CORS requests by default. * Add --public-domain flag to allow specifying a CORS allowed domain. * Warn about using "*" without an API key. --- examples/server/README.md | 1 + examples/server/server.cpp | 45 ++++++++++++++++--- .../server/tests/features/parallel.feature | 2 +- .../server/tests/features/security.feature | 2 +- examples/server/tests/features/server.feature | 2 +- examples/server/tests/features/steps/steps.py | 8 ++-- .../tests/features/wrong_usages.feature | 2 +- 7 files changed, 49 insertions(+), 13 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 397ee8252..9f3a85bce 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -42,6 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`. - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`. - `--port`: Set the port to listen. Default: `8080`. +- `--public-domain`: Set a public domain which will be allowed for Cross Origin Requests. If you are using the server as an API from a browser, this is required. - `--path`: path from which to serve static files (default examples/server/public) - `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys. - `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 99f759c63..6009eaaeb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -34,9 +34,10 @@ using json = nlohmann::json; struct server_params { - std::string hostname = "127.0.0.1"; + std::string hostname = "localhost"; std::vector api_keys; std::string public_path = "examples/server/public"; + std::string public_domain = "http://localhost:8080"; std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; @@ -2073,6 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); + printf(" --public-domain DOMAIN a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2132,6 +2134,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } sparams.hostname = argv[i]; } + else if (arg == "--public-domain") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + sparams.public_domain = argv[i]; + } else if (arg == "--path") { if (++i >= argc) @@ -2777,14 +2788,38 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); - // Allow CORS requests on all routes - svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) { + // Disallow CORS requests from any origin not specified in the --public-domain flag + LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.public_domain}}); + svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "*"); res.set_header("Access-Control-Allow-Headers", "*"); + + // Allow options so that request will return a specific error message rather than a generic CORS error + if (req.method == "OPTIONS") { + res.status = 200; + return httplib::Server::HandlerResponse::Handled; + } + + if (req.has_header("Origin") && sparams.public_domain != "*") { + if (req.get_header_value("Origin") != sparams.public_domain) { + LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); + res.status = 403; // HTTP Forbidden + res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); + return httplib::Server::HandlerResponse::Handled; + } + } + + return httplib::Server::HandlerResponse::Unhandled; }); + if (sparams.public_domain == "*"){ + if (sparams.api_keys.size() == 0) { + LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {}); + } + } + svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { server_state current_state = state.load(); switch(current_state) { @@ -3069,7 +3104,7 @@ int main(int argc, char **argv) return false; }); - svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request &, httplib::Response &res) { json data = { { "user_name", llama.name_user.c_str() }, @@ -3156,7 +3191,7 @@ int main(int argc, char **argv) } }); - svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request& req, httplib::Response& res) + svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request&, httplib::Response& res) { std::time_t t = std::time(0); diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 86cdf7282..4e023929c 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -3,7 +3,7 @@ Feature: Parallel Background: Server startup - Given a server listening on localhost:8080 + Given a server listening on 127.0.0.1:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And 42 as server seed And 512 as batch size diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index d5eaddce3..0049e1a25 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -3,7 +3,7 @@ Feature: Security Background: Server startup with an api key defined - Given a server listening on localhost:8080 + Given a server listening on 127.0.0.1:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a server api key llama.cpp Then the server is starting diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index 7c977bcce..fd2c5078c 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -3,7 +3,7 @@ Feature: llama.cpp server Background: Server startup - Given a server listening on localhost:8080 + Given a server listening on 127.0.0.1:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a model alias tinyllama-2 And 42 as server seed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 319527802..aa5870da6 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -122,7 +122,7 @@ def step_start_server(context): attempts += 1 if attempts > 20: assert False, "server not started" - print(f"waiting for server to start, connect error code = {result}...") + print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...") time.sleep(0.1) @@ -609,7 +609,7 @@ async def request_completion(prompt, user_api_key=None): if debug: print(f"Sending completion request: {prompt}") - origin = "my.super.domain" + origin = "http://localhost:8080" headers = { 'Origin': origin } @@ -678,7 +678,7 @@ async def oai_chat_completions(user_prompt, } } if async_client: - origin = 'llama.cpp' + origin = "http://localhost:8080" headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} async with aiohttp.ClientSession() as session: async with session.post(f'{base_url}{base_path}', @@ -774,7 +774,7 @@ async def request_oai_embeddings(input, # openai client always expects an api_key user_api_key = user_api_key if user_api_key is not None else 'nope' if async_client: - origin = 'llama.cpp' + origin = "http://localhost:8080" if user_api_key is not None: headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} async with aiohttp.ClientSession() as session: diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature index cf14b3b44..9fc61c390 100644 --- a/examples/server/tests/features/wrong_usages.feature +++ b/examples/server/tests/features/wrong_usages.feature @@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server # to cap the number of tokens any completion request can generate # or pass n_predict/max_tokens in the request. Scenario: Infinite loop - Given a server listening on localhost:8080 + Given a server listening on 127.0.0.1:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models # Uncomment below to fix the issue #And 64 server max tokens to predict From 320e89c0105be96a229bb1c3da070a5ad0c2f931 Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Sun, 3 Mar 2024 13:00:33 -0800 Subject: [PATCH 3/6] Change allowed methods. Rename cors flag. * Restrict HTTP requests to GET, POST, and OPTIONS. * rename cors flag from "--public-domain" to "--http-cors-origin" --- examples/server/server.cpp | 20 +++++++++---------- .../server/tests/features/security.feature | 12 +++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 16d60b968..d7769a7c4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -34,10 +34,10 @@ using json = nlohmann::json; struct server_params { - std::string hostname = "localhost"; + std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; - std::string public_domain = "http://localhost:8080"; + std::string http_cors_origin = "http://localhost:8080"; std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; @@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); - printf(" --public-domain DOMAIN a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str()); + printf(" --http-cors-origin DOMAIN a domain which will allow cross origin requests. (default: %s)\n", sparams.http_cors_origin.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2134,14 +2134,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } sparams.hostname = argv[i]; } - else if (arg == "--public-domain") + else if (arg == "--http-cors-origin") { if (++i >= argc) { invalid_param = true; break; } - sparams.public_domain = argv[i]; + sparams.http_cors_origin = argv[i]; } else if (arg == "--path") { @@ -2789,11 +2789,11 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); // Disallow CORS requests from any origin not specified in the --public-domain flag - LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.public_domain}}); + LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}}); svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "*"); + res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); res.set_header("Access-Control-Allow-Headers", "*"); // Allow options so that request will return a specific error message rather than a generic CORS error @@ -2802,8 +2802,8 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Handled; } - if (req.has_header("Origin") && sparams.public_domain != "*") { - if (req.get_header_value("Origin") != sparams.public_domain) { + if (req.has_header("Origin") && sparams.http_cors_origin != "*") { + if (req.get_header_value("Origin") != sparams.http_cors_origin) { LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); res.status = 403; // HTTP Forbidden res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); @@ -2814,7 +2814,7 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Unhandled; }); - if (sparams.public_domain == "*"){ + if (sparams.http_cors_origin == "*"){ if (sparams.api_keys.size() == 0) { LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {}); } diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index 0049e1a25..9b5f4f101 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -43,9 +43,9 @@ Feature: Security Then CORS header is set to Examples: Headers - | origin | cors_header | cors_header_value | - | localhost | Access-Control-Allow-Origin | localhost | - | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | - | origin | Access-Control-Allow-Credentials | true | - | web.mydomain.fr | Access-Control-Allow-Methods | * | - | web.mydomain.fr | Access-Control-Allow-Headers | * | + | origin | cors_header | cors_header_value | + | localhost | Access-Control-Allow-Origin | localhost | + | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | + | origin | Access-Control-Allow-Credentials | true | + | web.mydomain.fr | Access-Control-Allow-Methods | GET, POST, OPTIONS | + | web.mydomain.fr | Access-Control-Allow-Headers | * | From 2dc03f404d4b2460956074244466df01e76838ec Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Sun, 3 Mar 2024 14:09:41 -0800 Subject: [PATCH 4/6] Update --http-cors-origin flag description --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d7769a7c4..cf5f30d61 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); - printf(" --http-cors-origin DOMAIN a domain which will allow cross origin requests. (default: %s)\n", sparams.http_cors_origin.c_str()); + printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. (default: %s)\n", sparams.http_cors_origin.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); From f4ca2301bbfe44b45c1d5b38bcb0d422ba2037af Mon Sep 17 00:00:00 2001 From: StrangebytesDev Date: Sun, 3 Mar 2024 15:29:32 -0800 Subject: [PATCH 5/6] Allow setting multiple CORS enabled origins. * Allow setting multiple CORS enabled origins. * Add both "http://localhost:8080" and "http://127.0.0.1:8080" by default. * Move CORS logging below server startup to make it more visible. --- examples/server/server.cpp | 39 ++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cf5f30d61..6294e7bc6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -37,7 +37,7 @@ struct server_params { std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; - std::string http_cors_origin = "http://localhost:8080"; + std::vector http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"}; std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; @@ -2074,7 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); - printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. (default: %s)\n", sparams.http_cors_origin.c_str()); + printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default."); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); @@ -2141,7 +2141,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - sparams.http_cors_origin = argv[i]; + std::string cors_origins = argv[i]; + std::stringstream ss(cors_origins); + std::string cors_origin; + while (std::getline(ss, cors_origin, ',')) { + sparams.http_cors_origin.push_back(cors_origin); + } } else if (arg == "--path") { @@ -2788,8 +2793,7 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); - // Disallow CORS requests from any origin not specified in the --public-domain flag - LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.http_cors_origin}}); + // Disallow CORS requests from any origin not specified in the --http-cors-origin flag svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); @@ -2801,9 +2805,10 @@ int main(int argc, char **argv) res.status = 200; return httplib::Server::HandlerResponse::Handled; } - - if (req.has_header("Origin") && sparams.http_cors_origin != "*") { - if (req.get_header_value("Origin") != sparams.http_cors_origin) { + + // Check if the request is from any of the allowed origins + if (req.has_header("Origin") && sparams.http_cors_origin[2] != "*"){ + if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) { LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}}); res.status = 403; // HTTP Forbidden res.set_content(R"({"error": "Origin is not allowed."})", "application/json"); @@ -2814,12 +2819,6 @@ int main(int argc, char **argv) return httplib::Server::HandlerResponse::Unhandled; }); - if (sparams.http_cors_origin == "*"){ - if (sparams.api_keys.size() == 0) { - LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {}); - } - } - svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { server_state current_state = state.load(); switch(current_state) { @@ -3502,6 +3501,18 @@ int main(int argc, char **argv) svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); + + std::string cors_enabled_domains = ""; + for (const auto &domain : sparams.http_cors_origin) { + cors_enabled_domains += domain + " "; + if (domain == "*") { + if (sparams.api_keys.size() == 0) { + LOG_WARNING("CORS requests are enabled for any request without an API key set. This is not recommended.", {}); + } + } + } + LOG_INFO("CORS enabled domains", {{"domains", cors_enabled_domains}}); + // run the HTTP server in a thread - see comment below std::thread t([&]() { From 67e60c0da4512f0a6f3d1c76448c783bf2c92aa4 Mon Sep 17 00:00:00 2001 From: StrangeBytesDev Date: Fri, 10 May 2024 09:39:14 -0700 Subject: [PATCH 6/6] Basic fix to cors check --- examples/server/server.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a76c0cded..22d21453d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -124,6 +124,7 @@ struct server_params { std::string chat_template = ""; std::string system_prompt = ""; + std::vector http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"}; std::vector api_keys; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -2972,7 +2973,7 @@ int main(int argc, char ** argv) { svr->set_default_headers({{"Server", "llama.cpp"}}); // Disallow CORS requests from any origin not specified in the --http-cors-origin flag - svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { + svr->set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");