Add --public-domain flag to server to enable CORS requests.

* Disable CORS requests by default.
* Add a --public-domain flag for specifying an allowed CORS domain.
* Warn about using "*" without an API key (the resulting origin check is sketched below).
StrangebytesDev 2024-03-02 17:31:57 -08:00
parent ab7a989293
commit 1e6a2f12c6
7 changed files with 49 additions and 13 deletions
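The policy this commit implements is a plain string comparison: a request's `Origin` header must exactly match the configured `--public-domain` value, and `"*"` accepts any origin. A minimal sketch of that decision logic, distilled from the handler in the diff below (the function name `origin_allowed` is ours, not part of the commit):

```cpp
#include <iostream>
#include <string>

// Illustrative distillation of the origin policy in this commit: exact match
// against the configured public domain, with "*" acting as a wildcard.
// The commit itself implements this inline in the pre-routing handler.
static bool origin_allowed(const std::string &origin, const std::string &public_domain) {
    if (public_domain == "*") {
        return true; // any origin accepted; warned against when no API key is set
    }
    return origin == public_domain;
}

int main() {
    const std::string domain = "http://localhost:8080"; // the commit's default
    std::cout << origin_allowed("http://localhost:8080", domain) << '\n'; // 1: allowed
    std::cout << origin_allowed("https://evil.example",  domain) << '\n'; // 0: rejected with HTTP 403
    std::cout << origin_allowed("https://evil.example",  "*")    << '\n'; // 1: wildcard allows all
}
```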

examples/server/README.md

@@ -42,6 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
 - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 - `--port`: Set the port to listen. Default: `8080`.
+- `--public-domain`: Set a public domain that is allowed to make Cross-Origin Requests. If you are using the server as an API from a browser, this is required (see the example request below).
 - `--path`: path from which to serve static files (default examples/server/public)
 - `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
 - `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`.
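For example, suppose the server is started with `--public-domain http://localhost:8080` and `--api-key secret` (both values are placeholders). A client then has to send a matching `Origin` header plus the Bearer token; a minimal sketch using cpp-httplib's client, the same library the server embeds:

```cpp
#include <iostream>
#include "httplib.h"

int main() {
    // Assumes a server started with:
    //   ./server -m model.gguf --public-domain http://localhost:8080 --api-key secret
    // (model path, domain, and key are placeholders for illustration)
    httplib::Client cli("http://localhost:8080");
    httplib::Headers headers = {
        {"Origin",        "http://localhost:8080"}, // must exactly match --public-domain
        {"Authorization", "Bearer secret"},         // required once an API key is set
    };
    auto res = cli.Get("/health", headers);
    if (res) {
        // A mismatched Origin yields 403 and {"error": "Origin is not allowed."}
        std::cout << res->status << " " << res->body << "\n";
    } else {
        std::cerr << "request failed\n";
    }
}
```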

examples/server/server.cpp

@@ -34,9 +34,10 @@
 using json = nlohmann::json;

 struct server_params {
-    std::string hostname = "127.0.0.1";
+    std::string hostname = "localhost";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string public_domain = "http://localhost:8080";
     std::string chat_template = "";
     int32_t port = 8080;
     int32_t read_timeout = 600;
@@ -2073,6 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --host                 ip address to listen (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT            port to listen (default: %d)\n", sparams.port);
     printf("  --path PUBLIC_PATH     path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --public-domain DOMAIN a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str());
     printf("  --api-key API_KEY      optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME   path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
     printf("  -to N, --timeout N     server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@@ -2132,6 +2134,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             sparams.hostname = argv[i];
         }
+        else if (arg == "--public-domain")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.public_domain = argv[i];
+        }
         else if (arg == "--path")
         {
             if (++i >= argc)
@@ -2777,14 +2788,38 @@ int main(int argc, char **argv)
     svr.set_default_headers({{"Server", "llama.cpp"}});

-    // Allow CORS requests on all routes
-    svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) {
+    // Disallow CORS requests from any origin not specified in the --public-domain flag
+    LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.public_domain}});
+    svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods", "*");
         res.set_header("Access-Control-Allow-Headers", "*");
+
+        // Allow OPTIONS so that requests return a specific error message rather than a generic CORS error
+        if (req.method == "OPTIONS") {
+            res.status = 200;
+            return httplib::Server::HandlerResponse::Handled;
+        }
+
+        if (req.has_header("Origin") && sparams.public_domain != "*") {
+            if (req.get_header_value("Origin") != sparams.public_domain) {
+                LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
+                res.status = 403; // HTTP Forbidden
+                res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
+                return httplib::Server::HandlerResponse::Handled;
+            }
+        }
+
+        return httplib::Server::HandlerResponse::Unhandled;
     });

+    if (sparams.public_domain == "*") {
+        if (sparams.api_keys.size() == 0) {
+            LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {});
+        }
+    }
+
     svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
         server_state current_state = state.load();
         switch(current_state) {
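Pulled out of the surrounding server code, the gate in the hunk above can be exercised on its own. A standalone sketch, not the commit's code; the hard-coded domain and the `/health` stub are stand-ins:

```cpp
#include <string>
#include "httplib.h"

int main() {
    const std::string public_domain = "http://localhost:8080"; // stand-in for sparams.public_domain

    httplib::Server svr;
    // Every route passes through this handler first, mirroring the commit's approach.
    svr.set_pre_routing_handler([&](const httplib::Request &req, httplib::Response &res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (req.method == "OPTIONS") {
            res.status = 200; // answer preflights so browsers surface the real error
            return httplib::Server::HandlerResponse::Handled;
        }
        if (req.has_header("Origin") && public_domain != "*" &&
            req.get_header_value("Origin") != public_domain) {
            res.status = 403;
            res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
            return httplib::Server::HandlerResponse::Handled;
        }
        return httplib::Server::HandlerResponse::Unhandled; // fall through to the route
    });

    svr.Get("/health", [](const httplib::Request &, httplib::Response &res) {
        res.set_content(R"({"status": "ok"})", "application/json");
    });

    svr.listen("127.0.0.1", 8080);
}
```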
@@ -3069,7 +3104,7 @@ int main(int argc, char **argv)
         return false;
     });

-    svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
+    svr.Get("/props", [&llama](const httplib::Request &, httplib::Response &res)
     {
         json data = {
             { "user_name", llama.name_user.c_str() },
@@ -3156,7 +3191,7 @@ int main(int argc, char **argv)
         }
     });

-    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
+    svr.Get("/v1/models", [&params, &model_meta](const httplib::Request&, httplib::Response& res)
     {
         std::time_t t = std::time(0);

examples/server/tests/features/parallel.feature

@@ -3,7 +3,7 @@
 Feature: Parallel

   Background: Server startup
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And 42 as server seed
     And 512 as batch size

examples/server/tests/features/security.feature

@@ -3,7 +3,7 @@
 Feature: Security

   Background: Server startup with an api key defined
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a server api key llama.cpp
     Then the server is starting

examples/server/tests/features/server.feature

@@ -3,7 +3,7 @@
 Feature: llama.cpp server

   Background: Server startup
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model alias tinyllama-2
     And 42 as server seed

examples/server/tests/features/steps/steps.py

@@ -122,7 +122,7 @@ def step_start_server(context):
                 attempts += 1
                 if attempts > 20:
                     assert False, "server not started"
-                print(f"waiting for server to start, connect error code = {result}...")
+                print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
                 time.sleep(0.1)
@@ -609,7 +609,7 @@ async def request_completion(prompt,
                              user_api_key=None):
     if debug:
         print(f"Sending completion request: {prompt}")
-    origin = "my.super.domain"
+    origin = "http://localhost:8080"
     headers = {
         'Origin': origin
     }
@@ -678,7 +678,7 @@ async def oai_chat_completions(user_prompt,
             }
         }
     if async_client:
-        origin = 'llama.cpp'
+        origin = "http://localhost:8080"
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
             async with session.post(f'{base_url}{base_path}',
@@ -774,7 +774,7 @@ async def request_oai_embeddings(input,
     # openai client always expects an api_key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
-        origin = 'llama.cpp'
+        origin = "http://localhost:8080"
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:

examples/server/tests/features/wrong_usages.feature

@@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
   # to cap the number of tokens any completion request can generate
   # or pass n_predict/max_tokens in the request.
   Scenario: Infinite loop
-    Given a server listening on localhost:8080
+    Given a server listening on 127.0.0.1:8080
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     # Uncomment below to fix the issue
     #And 64 server max tokens to predict