diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7e0d068f8..504639ef5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -124,6 +124,7 @@ struct server_params {
     std::string chat_template = "";
     std::string system_prompt = "";
 
+    std::vector<std::string> http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"};
     std::vector<std::string> api_keys;
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2387,6 +2388,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
     printf("  --lora-base FNAME         optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf("  --host                    ip address to listen (default  (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT               port to listen (default  (default: %d)\n", sparams.port);
+    printf("  --http-cors-origin DOMAIN comma-separated list of origins (e.g. https://example.com) allowed to access the API. Use * to allow all origins (insecure without --api-key). Required when calling the server from a browser. localhost and 127.0.0.1 are allowed by default.\n");
     printf("  --path PUBLIC_PATH        path from which to serve static files (default: disabled)\n");
     printf("  --api-key API_KEY         optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME      path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
@@ -2445,6 +2447,17 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             sparams.hostname = argv[i];
+        } else if (arg == "--http-cors-origin") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string cors_origins = argv[i];
+            std::stringstream ss(cors_origins);
+            std::string cors_origin;
+            while (std::getline(ss, cors_origin, ',')) {
+                sparams.http_cors_origin.push_back(cors_origin);
+            }
         } else if (arg == "--path") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2962,16 +2975,33 @@ int main(int argc, char ** argv) {
     std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     svr->set_default_headers({{"Server", "llama.cpp"}});
-
-    // CORS preflight
-    svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+    // Disallow CORS requests from any origin not specified via the --http-cors-origin flag
+    svr->set_pre_routing_handler([&sparams](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res.set_header("Access-Control-Allow-Credentials", "true");
-        res.set_header("Access-Control-Allow-Methods", "POST");
-        res.set_header("Access-Control-Allow-Headers", "*");
-        return res.set_content("", "application/json; charset=utf-8");
-    });
-
+        res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
+        res.set_header("Access-Control-Allow-Headers", "*");
+        // Answer OPTIONS preflights here so a rejected request surfaces the specific error below rather than a generic CORS error
+        if (req.method == "OPTIONS") {
+            res.status = 200;
+            return httplib::Server::HandlerResponse::Handled;
+        }
+
+        // Reject requests whose Origin is not in the allow-list; a literal "*" entry allows all origins
+        const bool allow_all = std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), "*") != sparams.http_cors_origin.end();
+        if (req.has_header("Origin") && !allow_all) {
+            if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) {
+                LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
+                res.status = 403; // HTTP Forbidden
+                res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
+                return httplib::Server::HandlerResponse::Handled;
+            }
+        }
+
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
     svr->set_logger(log_server_request);
 
     auto res_error = [](httplib::Response & res, json error_data) {
diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature
index eb82e7aca..3d1c50b50 100644
--- a/examples/server/tests/features/security.feature
+++ b/examples/server/tests/features/security.feature
@@ -3,7 +3,7 @@ Feature: Security
 
   Background: Server startup with an api key defined
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a server api key llama.cpp
     Then the server is starting
@@ -60,9 +60,9 @@ Feature: Security
     Then CORS header <cors_header> is set to <cors_header_value>
 
     Examples: Headers
-      | origin          | cors_header                      | cors_header_value |
-      | localhost       | Access-Control-Allow-Origin      | localhost         |
-      | web.mydomain.fr | Access-Control-Allow-Origin      | web.mydomain.fr   |
-      | origin          | Access-Control-Allow-Credentials | true              |
-      | web.mydomain.fr | Access-Control-Allow-Methods     | POST              |
-      | web.mydomain.fr | Access-Control-Allow-Headers     | *                 |
+      | origin          | cors_header                      | cors_header_value  |
+      | localhost       | Access-Control-Allow-Origin      | localhost          |
+      | web.mydomain.fr | Access-Control-Allow-Origin      | web.mydomain.fr    |
+      | origin          | Access-Control-Allow-Credentials | true               |
+      | web.mydomain.fr | Access-Control-Allow-Methods     | GET, POST, OPTIONS |
+      | web.mydomain.fr | Access-Control-Allow-Headers     | *                  |
diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index d21c09135..db7153b28 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -3,7 +3,7 @@ Feature: llama.cpp server
 
   Background: Server startup
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model file test-model.gguf
     And a model alias tinyllama-2
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 577b87af3..4c029f9bf 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -179,7 +179,7 @@ def step_start_server(context):
         attempts += 1
         if attempts > max_attempts:
             assert False, "server not started"
-        print(f"waiting for server to start, connect error code = {result}...")
+        print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
         time.sleep(0.1)
 
@@ -849,7 +849,7 @@ async def request_completion(prompt,
                              temperature=None):
     if debug:
         print(f"Sending completion request: {prompt}")
-    origin = "my.super.domain"
+    origin = "http://localhost:8080"
     headers = {
         'Origin': origin
     }
@@ -927,7 +927,7 @@ async def oai_chat_completions(user_prompt,
         }
     }
     if async_client:
-        origin = 'llama.cpp'
+        origin = "http://localhost:8080"
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
             async with session.post(f'{base_url}{base_path}',
diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature
index cf14b3b44..9fc61c390 100644
--- a/examples/server/tests/features/wrong_usages.feature
+++ b/examples/server/tests/features/wrong_usages.feature
@@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
   # to cap the number of tokens any completion request can generate
   # or pass n_predict/max_tokens in the request.
   Scenario: Infinite loop
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     # Uncomment below to fix the issue
     #And 64 server max tokens to predict
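Not part of the diff, but a quick way to sanity-check the origin filtering from outside the Gherkin suite: a minimal stdlib-only sketch that sends requests with different `Origin` headers and inspects the result. It assumes a server built from this branch is listening on 127.0.0.1:8080, was started with `--http-cors-origin https://web.mydomain.fr`, and exposes the `/health` endpoint the test suite already polls during startup; the hostnames are purely illustrative.

```python
# Minimal CORS smoke test for the new --http-cors-origin flag (assumed setup:
#   ./server -m <model.gguf> --port 8080 --http-cors-origin https://web.mydomain.fr).
import urllib.error
import urllib.request

BASE_URL = "http://127.0.0.1:8080"

def probe(origin: str) -> None:
    """GET /health with the given Origin header and print status plus CORS header."""
    req = urllib.request.Request(f"{BASE_URL}/health", headers={"Origin": origin})
    try:
        with urllib.request.urlopen(req) as res:
            print(f"{origin} -> {res.status}, "
                  f"ACAO: {res.headers.get('Access-Control-Allow-Origin')}")
    except urllib.error.HTTPError as err:
        # Origins missing from the allow-list should get 403 with a JSON error body
        print(f"{origin} -> {err.code}, body: {err.read().decode()}")

probe("https://web.mydomain.fr")  # passed via --http-cors-origin: expect 200
probe("http://localhost:8080")    # in the default allow-list: expect 200
probe("https://evil.example")     # not in the allow-list: expect 403
```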
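One consequence of the pre-routing design worth noting: because the handler answers OPTIONS with 200 before the origin check runs, preflights always succeed, and it is the subsequent GET or POST that comes back as 403 with a JSON body. A browser therefore reports the server's specific "Origin is not allowed." message instead of a generic CORS failure, which is exactly the behaviour the in-code comment describes.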