Merge 67e60c0da4
into 9c4fdcbec8
This commit is contained in:
commit
b4877e376e
5 changed files with 50 additions and 20 deletions
|
@ -124,6 +124,7 @@ struct server_params {
|
|||
std::string chat_template = "";
|
||||
std::string system_prompt = "";
|
||||
|
||||
std::vector<std::string> http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"};
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
|
@ -2387,6 +2388,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
|||
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
|
||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||
printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default.");
|
||||
printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
|
||||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
||||
|
@ -2445,6 +2447,17 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
|||
break;
|
||||
}
|
||||
sparams.hostname = argv[i];
|
||||
} else if (arg == "--http-cors-origin") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string cors_origins = argv[i];
|
||||
std::stringstream ss(cors_origins);
|
||||
std::string cors_origin;
|
||||
while (std::getline(ss, cors_origin, ',')) {
|
||||
sparams.http_cors_origin.push_back(cors_origin);
|
||||
}
|
||||
} else if (arg == "--path") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -2962,16 +2975,33 @@ int main(int argc, char ** argv) {
|
|||
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
|
||||
|
||||
svr->set_default_headers({{"Server", "llama.cpp"}});
|
||||
|
||||
// CORS preflight
|
||||
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
// Disallow CORS requests from any origin not specified in the --http-cors-origin flag
|
||||
svr->set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
return res.set_content("", "application/json; charset=utf-8");
|
||||
});
|
||||
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
|
||||
// Allow options so that request will return a specific error message rather than a generic CORS error
|
||||
if (req.method == "OPTIONS") {
|
||||
res.status = 200;
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
|
||||
// Check if the request is from any of the allowed origins
|
||||
if (req.has_header("Origin") && sparams.http_cors_origin[2] != "*"){
|
||||
if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) {
|
||||
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
|
||||
res.status = 403; // HTTP Forbidden
|
||||
res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, json error_data) {
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
Feature: Security
|
||||
|
||||
Background: Server startup with an api key defined
|
||||
Given a server listening on localhost:8080
|
||||
Given a server listening on 127.0.0.1:8080
|
||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||
And a server api key llama.cpp
|
||||
Then the server is starting
|
||||
|
@ -60,9 +60,9 @@ Feature: Security
|
|||
Then CORS header <cors_header> is set to <cors_header_value>
|
||||
|
||||
Examples: Headers
|
||||
| origin | cors_header | cors_header_value |
|
||||
| localhost | Access-Control-Allow-Origin | localhost |
|
||||
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
|
||||
| origin | Access-Control-Allow-Credentials | true |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
||||
| origin | cors_header | cors_header_value |
|
||||
| localhost | Access-Control-Allow-Origin | localhost |
|
||||
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
|
||||
| origin | Access-Control-Allow-Credentials | true |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | GET, POST, OPTIONS |
|
||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
Feature: llama.cpp server
|
||||
|
||||
Background: Server startup
|
||||
Given a server listening on localhost:8080
|
||||
Given a server listening on 127.0.0.1:8080
|
||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||
And a model file test-model.gguf
|
||||
And a model alias tinyllama-2
|
||||
|
|
|
@ -179,7 +179,7 @@ def step_start_server(context):
|
|||
attempts += 1
|
||||
if attempts > max_attempts:
|
||||
assert False, "server not started"
|
||||
print(f"waiting for server to start, connect error code = {result}...")
|
||||
print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
|
@ -849,7 +849,7 @@ async def request_completion(prompt,
|
|||
temperature=None):
|
||||
if debug:
|
||||
print(f"Sending completion request: {prompt}")
|
||||
origin = "my.super.domain"
|
||||
origin = "http://localhost:8080"
|
||||
headers = {
|
||||
'Origin': origin
|
||||
}
|
||||
|
@ -927,7 +927,7 @@ async def oai_chat_completions(user_prompt,
|
|||
}
|
||||
}
|
||||
if async_client:
|
||||
origin = 'llama.cpp'
|
||||
origin = "http://localhost:8080"
|
||||
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}{base_path}',
|
||||
|
|
|
@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
|
|||
# to cap the number of tokens any completion request can generate
|
||||
# or pass n_predict/max_tokens in the request.
|
||||
Scenario: Infinite loop
|
||||
Given a server listening on localhost:8080
|
||||
Given a server listening on 127.0.0.1:8080
|
||||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||
# Uncomment below to fix the issue
|
||||
#And 64 server max tokens to predict
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue