server : to snake_case

This commit is contained in:
Georgi Gerganov 2023-12-15 13:47:14 +02:00
parent e4bf96329e
commit 81e67a218e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2682,24 +2682,28 @@ int main(int argc, char **argv)
httplib::Server svr;
// Middleware for API key validation
auto validateApiKey = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, skip validation
if (sparams.api_key.empty()) {
return true;
}
// Check for API key in the header
auto authHeader = req.get_header_value("Authorization");
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (authHeader.substr(0, prefix.size()) == prefix) {
std::string receivedApiKey = authHeader.substr(prefix.size());
if (receivedApiKey == sparams.api_key) {
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (received_api_key == sparams.api_key) {
return true; // API key is valid
}
}
// API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain");
res.status = 401; // Unauthorized
LOG_WARNING("Unauthorized: Invalid API Key", {});
return false;
};
@ -2745,9 +2749,9 @@ int main(int argc, char **argv)
res.set_content(data.dump(), "application/json");
});
svr.Post("/completion", [&llama, &validateApiKey](const httplib::Request &req, httplib::Response &res)
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
if (!validateApiKey(req, res)) {
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body);
@ -2836,9 +2840,9 @@ int main(int argc, char **argv)
});
// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validateApiKey](const httplib::Request &req, httplib::Response &res)
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
if (!validateApiKey(req, res)) {
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(json::parse(req.body));
@ -2909,9 +2913,9 @@ int main(int argc, char **argv)
}
});
svr.Post("/infill", [&llama, &validateApiKey](const httplib::Request &req, httplib::Response &res)
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
if (!validateApiKey(req, res)) {
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body);
@ -3079,15 +3083,15 @@ int main(int argc, char **argv)
// to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
std::unordered_map<std::string, std::string> logData;
logData["hostname"] = sparams.hostname;
logData["port"] = std::to_string(sparams.port);
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (!sparams.api_key.empty()) {
logData["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
}
LOG_INFO("HTTP server listening", logData);
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{