StrangeBytesDev 2024-05-17 15:10:32 +09:00 committed by GitHub
commit b4877e376e
5 changed files with 50 additions and 20 deletions


@@ -124,6 +124,7 @@ struct server_params {
std::string chat_template = "";
std::string system_prompt = "";
std::vector<std::string> http_cors_origin = {"http://localhost:8080", "http://127.0.0.1:8080"};
std::vector<std::string> api_keys;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2387,6 +2388,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --http-cors-origin DOMAIN Set what origin (example.com) is allowed to access the API. Use * to allow all origins (insecure without --api-key). If you are using the server as an API from a browser, this parameter is required. Includes localhost and 127.0.0.1 by default.");
printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
@@ -2445,6 +2447,17 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
break;
}
sparams.hostname = argv[i];
} else if (arg == "--http-cors-origin") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string cors_origins = argv[i];
std::stringstream ss(cors_origins);
std::string cors_origin;
while (std::getline(ss, cors_origin, ',')) {
sparams.http_cors_origin.push_back(cors_origin);
}
} else if (arg == "--path") {
if (++i >= argc) {
invalid_param = true;
@@ -2963,13 +2976,30 @@ int main(int argc, char ** argv) {
svr->set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
// Disallow CORS requests from any origin not specified in the --http-cors-origin flag
svr->set_pre_routing_handler([&sparams](const httplib::Request &req, httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "application/json; charset=utf-8");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
res.set_header("Access-Control-Allow-Headers", "*");
// Always answer OPTIONS preflights so that a disallowed request returns a specific error message rather than a generic CORS error
if (req.method == "OPTIONS") {
res.status = 200;
return httplib::Server::HandlerResponse::Handled;
}
// Check whether the request comes from one of the allowed origins; skip the check when "*" was passed to --http-cors-origin
const bool allow_all_origins = std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), "*") != sparams.http_cors_origin.end();
if (req.has_header("Origin") && !allow_all_origins) {
if (std::find(sparams.http_cors_origin.begin(), sparams.http_cors_origin.end(), req.get_header_value("Origin")) == sparams.http_cors_origin.end()) {
LOG_WARNING("Request from origin not allowed.", {{"origin", req.get_header_value("Origin")}});
res.status = 403; // HTTP Forbidden
res.set_content(R"({"error": "Origin is not allowed."})", "application/json");
return httplib::Server::HandlerResponse::Handled;
}
}
return httplib::Server::HandlerResponse::Unhandled;
});
svr->set_logger(log_server_request);

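For reference, here is a minimal, self-contained sketch of the allow-list behavior this diff adds to server.cpp: the comma-separated value of --http-cors-origin is appended to the two default origins, and a request is rejected unless its Origin header is listed or * is present. The helper names below are illustrative only and are not symbols from server.cpp.

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Illustrative stand-in for sparams.http_cors_origin, pre-seeded with the same defaults.
static std::vector<std::string> allowed_origins = {"http://localhost:8080", "http://127.0.0.1:8080"};

// Mirrors the argument parsing: split the flag value on commas and append each origin.
static void add_cors_origins(const std::string & flag_value) {
    std::stringstream ss(flag_value);
    std::string origin;
    while (std::getline(ss, origin, ',')) {
        allowed_origins.push_back(origin);
    }
}

// Mirrors the pre-routing check: allow everything when "*" is listed,
// otherwise only origins that appear in the allow-list.
static bool origin_allowed(const std::string & origin) {
    if (std::find(allowed_origins.begin(), allowed_origins.end(), "*") != allowed_origins.end()) {
        return true;
    }
    return std::find(allowed_origins.begin(), allowed_origins.end(), origin) != allowed_origins.end();
}

int main() {
    add_cors_origins("https://web.mydomain.fr,https://app.example.com");
    std::cout << origin_allowed("https://web.mydomain.fr") << "\n"; // 1: explicitly listed
    std::cout << origin_allowed("https://evil.example") << "\n";    // 0: the server would answer 403
    return 0;
}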

@@ -3,7 +3,7 @@
Feature: Security
Background: Server startup with an api key defined
Given a server listening on localhost:8080
Given a server listening on 127.0.0.1:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a server api key llama.cpp
Then the server is starting
@@ -60,9 +60,9 @@ Feature: Security
Then CORS header <cors_header> is set to <cors_header_value>
Examples: Headers
| origin | cors_header | cors_header_value |
| localhost | Access-Control-Allow-Origin | localhost |
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
| origin | Access-Control-Allow-Credentials | true |
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
| web.mydomain.fr | Access-Control-Allow-Headers | * |
| origin | cors_header | cors_header_value |
| localhost | Access-Control-Allow-Origin | localhost |
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
| origin | Access-Control-Allow-Credentials | true |
| web.mydomain.fr | Access-Control-Allow-Methods | GET, POST, OPTIONS |
| web.mydomain.fr | Access-Control-Allow-Headers | * |

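The scenario above asserts on the CORS headers through the test harness; the same behavior can be probed directly with the cpp-httplib client header that the server example already vendors. This is only a sketch, under the assumption that a server built from this change is listening on 127.0.0.1:8080; /health is simply a convenient endpoint to hit.

#include <iostream>
#include "httplib.h"

int main() {
    httplib::Client cli("http://127.0.0.1:8080");
    // Send the request as a browser would, with an Origin that is allowed by default.
    httplib::Headers headers = {{"Origin", "http://localhost:8080"}};

    auto res = cli.Get("/health", headers);
    if (!res) {
        std::cerr << "request failed\n";
        return 1;
    }
    // For an allowed origin the server echoes it back; a disallowed origin
    // would instead produce status 403 and a JSON error body.
    std::cout << res->status << "\n";
    std::cout << res->get_header_value("Access-Control-Allow-Origin") << "\n";
    std::cout << res->get_header_value("Access-Control-Allow-Methods") << "\n";
    return 0;
}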

@@ -3,7 +3,7 @@
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
Given a server listening on 127.0.0.1:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2


@@ -179,7 +179,7 @@ def step_start_server(context):
attempts += 1
if attempts > max_attempts:
assert False, "server not started"
print(f"waiting for server to start, connect error code = {result}...")
print(f"waiting for server to start on {context.server_fqdn}:{context.server_port}, connect error code = {result}...")
time.sleep(0.1)
@@ -849,7 +849,7 @@ async def request_completion(prompt,
temperature=None):
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
origin = "http://localhost:8080"
headers = {
'Origin': origin
}
@@ -927,7 +927,7 @@ async def oai_chat_completions(user_prompt,
}
}
if async_client:
origin = 'llama.cpp'
origin = "http://localhost:8080"
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}{base_path}',

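These step changes pin the test Origin to one of the default allowed origins. For completeness, a rough client-side sketch of the preflight path that the new pre-routing handler short-circuits with a 200 (again cpp-httplib; the port and the /completion path are assumptions, not something this diff dictates):

#include <iostream>
#include "httplib.h"

int main() {
    httplib::Client cli("http://127.0.0.1:8080");
    // A browser preflight sends OPTIONS together with the Origin it intends to use.
    httplib::Headers headers = {{"Origin", "http://localhost:8080"}};

    auto res = cli.Options("/completion", headers);
    if (res) {
        // The handler answers OPTIONS directly, so this prints 200 plus the allowed methods.
        std::cout << res->status << " "
                  << res->get_header_value("Access-Control-Allow-Methods") << "\n";
    }
    return 0;
}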

@@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server
# to cap the number of tokens any completion request can generate
# or pass n_predict/max_tokens in the request.
Scenario: Infinite loop
Given a server listening on localhost:8080
Given a server listening on 127.0.0.1:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
# Uncomment below to fix the issue
#And 64 server max tokens to predict