server: added support for multiple api keys, added loading api keys from file

This commit is contained in:
Michael Coppola 2024-01-10 14:17:17 -05:00
parent 57d016ba2d
commit df7ab297b8

View file

@ -38,7 +38,8 @@ using json = nlohmann::json;
struct server_params struct server_params
{ {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string api_key; std::string api_key_file;
std::vector<std::string> api_keys;
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
int32_t port = 8080; int32_t port = 8080;
int32_t read_timeout = 600; int32_t read_timeout = 600;
@ -2014,6 +2015,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
@ -2074,7 +2076,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.api_key = argv[i]; sparams.api_keys.push_back(argv[i]);
}
else if (arg == "--api-key-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
sparams.api_key_file = argv[i];
} }
else if (arg == "--timeout" || arg == "-to") else if (arg == "--timeout" || arg == "-to")
{ {
@ -2773,6 +2784,18 @@ int main(int argc, char **argv)
server_params_parse(argc, argv, sparams, params, llama); server_params_parse(argc, argv, sparams, params, llama);
// load api keys from file
if (!sparams.api_key_file.empty()) {
std::ifstream key_file(sparams.api_key_file);
std::string key;
while (std::getline(key_file, key)) {
if (key.size() > 0) {
sparams.api_keys.push_back(key);
}
}
key_file.close();
}
if (params.model_alias == "unknown") if (params.model_alias == "unknown")
{ {
params.model_alias = params.model; params.model_alias = params.model;
@ -2803,7 +2826,7 @@ int main(int argc, char **argv)
// Middleware for API key validation // Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, skip validation // If API key is not set, skip validation
if (sparams.api_key.empty()) { if (sparams.api_keys.empty()) {
return true; return true;
} }
@ -2812,7 +2835,7 @@ int main(int argc, char **argv)
std::string prefix = "Bearer "; std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) { if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size()); std::string received_api_key = auth_header.substr(prefix.size());
if (received_api_key == sparams.api_key) { if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
return true; // API key is valid return true; // API key is valid
} }
} }
@ -3216,10 +3239,13 @@ int main(int argc, char **argv)
log_data["hostname"] = sparams.hostname; log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port); log_data["port"] = std::to_string(sparams.port);
if (!sparams.api_key.empty()) { if (sparams.api_keys.size() == 1) {
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
} else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + (sparams.api_key_file.empty() ? "" : sparams.api_key_file + " ")+ "(" + std::to_string(sparams.api_keys.size()) + " keys loaded)";
} }
LOG_INFO("HTTP server listening", log_data); LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below // run the HTTP server in a thread - see comment below
std::thread t([&]() std::thread t([&]()