diff --git a/examples/server/README.md b/examples/server/README.md
index 397ee8252..9f3a85bce 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -42,6 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
 - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 - `--port`: Set the port to listen. Default: `8080`.
+- `--public-domain`: Set the public domain that is allowed to make Cross-Origin Requests (CORS). Required if you are calling the server API from a browser.
 - `--path`: path from which to serve static files (default examples/server/public)
 - `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
 - `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 99f759c63..6009eaaeb 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -34,9 +34,10 @@ using json = nlohmann::json;
 
 struct server_params
 {
-    std::string hostname = "127.0.0.1";
+    std::string hostname = "localhost";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string public_domain = "http://localhost:8080";
     std::string chat_template = "";
     int32_t port = 8080;
     int32_t read_timeout = 600;
@@ -2073,6 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --host                        ip address to listen (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT                   port to listen (default: %d)\n", sparams.port);
     printf("  --path PUBLIC_PATH            path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --public-domain DOMAIN        a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str());
     printf("  --api-key API_KEY             optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME          path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
     printf("  -to N, --timeout N            server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@@ -2132,6 +2134,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             sparams.hostname = argv[i];
         }
+        else if (arg == "--public-domain")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.public_domain = argv[i];
+        }
         else if (arg == "--path")
         {
             if (++i >= argc)
@@ -2777,14 +2788,38 @@ int main(int argc, char **argv)
 
     svr.set_default_headers({{"Server", "llama.cpp"}});
 
-    // Allow CORS requests on all routes
-    svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) {
+    // Disallow CORS requests from any origin not specified in the --public-domain flag
+    LOG_INFO("CORS requests enabled on domain", {{"domain", sparams.public_domain}});
+    svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods", "*");
         res.set_header("Access-Control-Allow-Headers", "*");
+
+        // Answer OPTIONS (preflight) requests so that a blocked request returns a specific error message rather than a generic CORS error
+        if (req.method == "OPTIONS") {
+            res.status = 200;
+            return httplib::Server::HandlerResponse::Handled;
+        }
+
+        if (req.has_header("Origin") && sparams.public_domain != "*") {
+            if (req.get_header_value("Origin") != sparams.public_domain) {
+                LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
+                res.status = 403; // HTTP Forbidden
+                res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
+                return httplib::Server::HandlerResponse::Handled;
+            }
+        }
+
+        return httplib::Server::HandlerResponse::Unhandled;
     });
+
+    if (sparams.public_domain == "*") {
+        if (sparams.api_keys.size() == 0) {
+            LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {});
+        }
+    }
 
     svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
         server_state current_state = state.load();
         switch(current_state) {
@@ -3069,7 +3104,7 @@ int main(int argc, char **argv)
             return false;
         });
 
-    svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
+    svr.Get("/props", [&llama](const httplib::Request &, httplib::Response &res)
     {
         json data = {
             { "user_name", llama.name_user.c_str() },
@@ -3156,7 +3191,7 @@ int main(int argc, char **argv)
         }
     });
 
-    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
+    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request&, httplib::Response& res)
     {
         std::time_t t = std::time(0);
 
diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature
index 86cdf7282..4e023929c 100644
--- a/examples/server/tests/features/parallel.feature
+++ b/examples/server/tests/features/parallel.feature
@@ -3,7 +3,7 @@ Feature: Parallel
 
   Background: Server startup
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And 42 as server seed
     And 512 as batch size
diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature
index d5eaddce3..0049e1a25 100644
--- a/examples/server/tests/features/security.feature
+++ b/examples/server/tests/features/security.feature
@@ -3,7 +3,7 @@ Feature: Security
 
   Background: Server startup with an api key defined
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a server api key llama.cpp
     Then the server is starting
diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index 7c977bcce..fd2c5078c 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -3,7 +3,7 @@ Feature: llama.cpp server
 
   Background: Server startup
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model alias tinyllama-2
     And 42 as server seed
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 319527802..aa5870da6 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -122,7 +122,7 @@ def step_start_server(context):
         attempts += 1
         if attempts > 20:
             assert False, "server not started"
-        print(f"waiting for server to start, connect error code = {result}...")
+        print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
         time.sleep(0.1)
 
 
@@ -609,7 +609,7 @@ async def request_completion(prompt, user_api_key=None):
     if debug:
         print(f"Sending completion request: {prompt}")
-    origin = "my.super.domain"
+    origin = "http://localhost:8080"
     headers = {
         'Origin': origin
     }
@@ -678,7 +678,7 @@ async def oai_chat_completions(user_prompt,
             }
         }
     if async_client:
-        origin = 'llama.cpp'
+        origin = "http://localhost:8080"
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
             async with session.post(f'{base_url}{base_path}',
@@ -774,7 +774,7 @@ async def request_oai_embeddings(input,
     # openai client always expects an api_key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
-        origin = 'llama.cpp'
+        origin = "http://localhost:8080"
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature
index cf14b3b44..9fc61c390 100644
--- a/examples/server/tests/features/wrong_usages.feature
+++ b/examples/server/tests/features/wrong_usages.feature
@@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
   # to cap the number of tokens any completion request can generate
   # or pass n_predict/max_tokens in the request.
   Scenario: Infinite loop
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     # Uncomment below to fix the issue
     #And 64 server max tokens to predict
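To sanity-check the new pre-routing handler end to end, here is a minimal client-side sketch in Python (using `requests`). It is an illustration only: it assumes a server built from this branch, running locally, started along the lines of `./server -m <model> --public-domain http://localhost:8080`; the port, origins, and expected status codes in the comments follow the handler above.

```python
# Sketch: exercise the three cases distinguished by the pre-routing handler.
# Assumes the server was started with --public-domain http://localhost:8080.
import requests

BASE = "http://localhost:8080"

# 1. An Origin that matches --public-domain should pass through untouched.
r = requests.get(f"{BASE}/health", headers={"Origin": "http://localhost:8080"})
print(r.status_code)  # expected: 200

# 2. Any other Origin should be stopped before routing with a JSON error body.
r = requests.get(f"{BASE}/health", headers={"Origin": "http://evil.example.com"})
print(r.status_code, r.text)  # expected: 403 {"error": "Origin is not allowed."}

# 3. OPTIONS (preflight) is always answered with 200, so a browser surfaces
#    the specific 403 body above instead of a generic CORS failure.
r = requests.options(f"{BASE}/v1/models", headers={"Origin": "http://evil.example.com"})
print(r.status_code)  # expected: 200
```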
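The new warning exists because `--public-domain "*"` disables the origin check entirely; pairing it with `--api-key` keeps the server from becoming an open endpoint. A hedged sketch of that pairing (the key value and endpoint are illustrative, not part of this patch):

```python
# Sketch: wildcard CORS plus an API key, e.g. a server started with
#   ./server -m <model> --public-domain "*" --api-key sk-illustrative
# Any Origin is then accepted, but requests still need the Bearer token.
import requests

BASE = "http://localhost:8080"
headers = {
    "Origin": "https://any.site.example",        # passes: "*" disables the origin check
    "Authorization": "Bearer sk-illustrative",   # still enforced by --api-key
}
r = requests.get(f"{BASE}/v1/models", headers=headers)
print(r.status_code)  # expected: 200; without the token the request is rejected
```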