Add --public-domain flag to server to enable CORS requests.
* Disable CORS requests by default. * Add --public-domain flag to allow specifying a CORS allowed domain. * Warn about using "*" without an API key.
This commit is contained in:
parent
ab7a989293
commit
1e6a2f12c6
7 changed files with 49 additions and 13 deletions
|
@ -42,6 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
|
||||||
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
|
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
|
||||||
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
|
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
|
||||||
- `--port`: Set the port to listen. Default: `8080`.
|
- `--port`: Set the port to listen. Default: `8080`.
|
||||||
|
- `--public-domain`: Set a public domain which will be allowed for Cross-Origin Requests (CORS). If you are using the server as an API from a browser, this is required.
|
||||||
- `--path`: path from which to serve static files (default examples/server/public)
|
- `--path`: path from which to serve static files (default examples/server/public)
|
||||||
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
|
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
|
||||||
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`.
|
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`.
|
||||||
|
|
|
@ -34,9 +34,10 @@
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
struct server_params {
|
struct server_params {
|
||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "localhost";
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
|
std::string public_domain = "http://localhost:8080";
|
||||||
std::string chat_template = "";
|
std::string chat_template = "";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
|
@ -2073,6 +2074,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||||
|
printf(" --public-domain DOMAIN a public domain to allow cross origin requests from (default: %s)\n", sparams.public_domain.c_str());
|
||||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
|
@ -2132,6 +2134,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
}
|
}
|
||||||
sparams.hostname = argv[i];
|
sparams.hostname = argv[i];
|
||||||
}
|
}
|
||||||
|
else if (arg == "--public-domain")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.public_domain = argv[i];
|
||||||
|
}
|
||||||
else if (arg == "--path")
|
else if (arg == "--path")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
|
@ -2777,14 +2788,38 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
svr.set_default_headers({{"Server", "llama.cpp"}});
|
svr.set_default_headers({{"Server", "llama.cpp"}});
|
||||||
|
|
||||||
// Allow CORS requests on all routes
|
// Disallow CORS requests from any origin not specified in the --public-domain flag
|
||||||
svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) {
|
LOG_INFO("CORS requests enabled on domain:", {{"domain", sparams.public_domain}});
|
||||||
|
svr.set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||||
res.set_header("Access-Control-Allow-Methods", "*");
|
res.set_header("Access-Control-Allow-Methods", "*");
|
||||||
res.set_header("Access-Control-Allow-Headers", "*");
|
res.set_header("Access-Control-Allow-Headers", "*");
|
||||||
|
|
||||||
|
// Allow options so that request will return a specific error message rather than a generic CORS error
|
||||||
|
if (req.method == "OPTIONS") {
|
||||||
|
res.status = 200;
|
||||||
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.has_header("Origin") && sparams.public_domain != "*") {
|
||||||
|
if (req.get_header_value("Origin") != sparams.public_domain) {
|
||||||
|
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
|
||||||
|
res.status = 403; // HTTP Forbidden
|
||||||
|
res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
|
||||||
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (sparams.public_domain == "*"){
|
||||||
|
if (sparams.api_keys.size() == 0) {
|
||||||
|
LOG_WARNING("Public domain is set to * without an API key specified. This is not recommended.", {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
|
||||||
server_state current_state = state.load();
|
server_state current_state = state.load();
|
||||||
switch(current_state) {
|
switch(current_state) {
|
||||||
|
@ -3069,7 +3104,7 @@ int main(int argc, char **argv)
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
|
svr.Get("/props", [&llama](const httplib::Request &, httplib::Response &res)
|
||||||
{
|
{
|
||||||
json data = {
|
json data = {
|
||||||
{ "user_name", llama.name_user.c_str() },
|
{ "user_name", llama.name_user.c_str() },
|
||||||
|
@ -3156,7 +3191,7 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request& req, httplib::Response& res)
|
svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request&, httplib::Response& res)
|
||||||
{
|
{
|
||||||
std::time_t t = std::time(0);
|
std::time_t t = std::time(0);
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
Feature: Parallel
|
Feature: Parallel
|
||||||
|
|
||||||
Background: Server startup
|
Background: Server startup
|
||||||
Given a server listening on localhost:8080
|
Given a server listening on 127.0.0.1:8080
|
||||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
And 42 as server seed
|
And 42 as server seed
|
||||||
And 512 as batch size
|
And 512 as batch size
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
Feature: Security
|
Feature: Security
|
||||||
|
|
||||||
Background: Server startup with an api key defined
|
Background: Server startup with an api key defined
|
||||||
Given a server listening on localhost:8080
|
Given a server listening on 127.0.0.1:8080
|
||||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
And a server api key llama.cpp
|
And a server api key llama.cpp
|
||||||
Then the server is starting
|
Then the server is starting
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
Feature: llama.cpp server
|
Feature: llama.cpp server
|
||||||
|
|
||||||
Background: Server startup
|
Background: Server startup
|
||||||
Given a server listening on localhost:8080
|
Given a server listening on 127.0.0.1:8080
|
||||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
And a model alias tinyllama-2
|
And a model alias tinyllama-2
|
||||||
And 42 as server seed
|
And 42 as server seed
|
||||||
|
|
|
@ -122,7 +122,7 @@ def step_start_server(context):
|
||||||
attempts += 1
|
attempts += 1
|
||||||
if attempts > 20:
|
if attempts > 20:
|
||||||
assert False, "server not started"
|
assert False, "server not started"
|
||||||
print(f"waiting for server to start, connect error code = {result}...")
|
print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
@ -609,7 +609,7 @@ async def request_completion(prompt,
|
||||||
user_api_key=None):
|
user_api_key=None):
|
||||||
if debug:
|
if debug:
|
||||||
print(f"Sending completion request: {prompt}")
|
print(f"Sending completion request: {prompt}")
|
||||||
origin = "my.super.domain"
|
origin = "http://localhost:8080"
|
||||||
headers = {
|
headers = {
|
||||||
'Origin': origin
|
'Origin': origin
|
||||||
}
|
}
|
||||||
|
@ -678,7 +678,7 @@ async def oai_chat_completions(user_prompt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if async_client:
|
if async_client:
|
||||||
origin = 'llama.cpp'
|
origin = "http://localhost:8080"
|
||||||
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f'{base_url}{base_path}',
|
async with session.post(f'{base_url}{base_path}',
|
||||||
|
@ -774,7 +774,7 @@ async def request_oai_embeddings(input,
|
||||||
# openai client always expects an api_key
|
# openai client always expects an api_key
|
||||||
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||||
if async_client:
|
if async_client:
|
||||||
origin = 'llama.cpp'
|
origin = "http://localhost:8080"
|
||||||
if user_api_key is not None:
|
if user_api_key is not None:
|
||||||
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
|
@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
|
||||||
# to cap the number of tokens any completion request can generate
|
# to cap the number of tokens any completion request can generate
|
||||||
# or pass n_predict/max_tokens in the request.
|
# or pass n_predict/max_tokens in the request.
|
||||||
Scenario: Infinite loop
|
Scenario: Infinite loop
|
||||||
Given a server listening on localhost:8080
|
Given a server listening on 127.0.0.1:8080
|
||||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
# Uncomment below to fix the issue
|
# Uncomment below to fix the issue
|
||||||
#And 64 server max tokens to predict
|
#And 64 server max tokens to predict
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue