diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f98843d85..0deceefa8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -248,7 +248,7 @@ struct server_slot { } } - json get_formated_timings() const { + json get_formatted_timings() const { return json { {"prompt_n", n_prompt_tokens_processed}, {"prompt_ms", t_prompt_processing}, @@ -1157,7 +1157,7 @@ struct server_context { return slot.has_next_token; // continue } - json get_formated_generation(const server_slot & slot) const { + json get_formatted_generation(const server_slot & slot) const { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); @@ -1469,7 +1469,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); + json slot_data = get_formatted_generation(slot); slot_data["id"] = slot.id; slot_data["id_task"] = slot.id_task; slot_data["state"] = slot.state; @@ -2205,7 +2205,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } sparams.api_keys.emplace_back(argv[i]); - } else if (arg == "--api-key-file") { + } else if (arg == "--api-key-file") { if (++i >= argc) { invalid_param = true; break; @@ -2216,8 +2216,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - sparams.api_keys = get_userdata(argv[i]); // read apikey json data - + std::string key; + while (std::getline(key_file, key)) { + if (key.size() > 0) { + sparams.api_keys.push_back(key); + } + } key_file.close(); } else if (arg == "--timeout" || arg == "-to") { if (++i >= argc) { @@ -2895,17 +2899,10 @@ int main(int argc, char ** argv) { log_data["port"] = std::to_string(sparams.port); // process api keys - if (sparams.api_keys.size() == 1) { // should we trap what happens if the size is zero? - log_data["api_key"] = "api_key: ****" + sparams.api_keys[0][0].substr(sparams.api_keys[0][0].length() - 4); + if (sparams.api_keys.size() == 1) { + log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); } else if (sparams.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; // diagnostic; suppress eventually - } - for (auto &item : sparams.api_keys) { - std::string username = item.first; - std::string apikey = item.second[0]; - std::string usercode = item.second[1]; - - LOG_TEE("Loaded api key for user %s: %s with usercodename %s\n", username.c_str(), apikey.c_str(), usercode.c_str()); + log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } // load the model @@ -2930,29 +2927,21 @@ int main(int argc, char ** argv) { // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { - // If API key is not set, because the file is empty, skip validation + // If API key is not set, skip validation if (sparams.api_keys.empty()) { return true; } - // Check for API key in the header (TODO: need to add username eventually but ...) + // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); std::string prefix = "Bearer "; if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); - LOG("Received API key = %s\n", received_api_key.c_str()); - - for (auto& item : sparams.api_keys) { - std::string username = item.first; - std::string apikey = item.second[0]; - std::string usercode = item.second[1]; // all three defined in anticipation of later use - if (received_api_key == apikey) { - LOG("Apikey found for user %s with username %s\n", username.c_str(), usercode.c_str()); - return true; - } - } + if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { + return true; // API key is valid } + } // API key is invalid or not provided res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");