Resolve merge conflicts in server

This commit is contained in:
pudepiedj 2024-03-08 13:38:16 +00:00
parent a6d2611624
commit 97ff2abc0e

View file

@ -248,7 +248,7 @@ struct server_slot {
} }
} }
json get_formated_timings() const { json get_formatted_timings() const {
return json { return json {
{"prompt_n", n_prompt_tokens_processed}, {"prompt_n", n_prompt_tokens_processed},
{"prompt_ms", t_prompt_processing}, {"prompt_ms", t_prompt_processing},
@ -1157,7 +1157,7 @@ struct server_context {
return slot.has_next_token; // continue return slot.has_next_token; // continue
} }
json get_formated_generation(const server_slot & slot) const { json get_formatted_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@ -1469,7 +1469,7 @@ struct server_context {
int n_processing_slots = 0; int n_processing_slots = 0;
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
json slot_data = get_formated_generation(slot); json slot_data = get_formatted_generation(slot);
slot_data["id"] = slot.id; slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task; slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state; slot_data["state"] = slot.state;
@ -2216,8 +2216,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.api_keys = get_userdata(argv[i]); // read apikey json data std::string key;
while (std::getline(key_file, key)) {
if (key.size() > 0) {
sparams.api_keys.push_back(key);
}
}
key_file.close(); key_file.close();
} else if (arg == "--timeout" || arg == "-to") { } else if (arg == "--timeout" || arg == "-to") {
if (++i >= argc) { if (++i >= argc) {
@ -2895,17 +2899,10 @@ int main(int argc, char ** argv) {
log_data["port"] = std::to_string(sparams.port); log_data["port"] = std::to_string(sparams.port);
// process api keys // process api keys
if (sparams.api_keys.size() == 1) { // should we trap what happens if the size is zero? if (sparams.api_keys.size() == 1) {
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0][0].substr(sparams.api_keys[0][0].length() - 4); log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
} else if (sparams.api_keys.size() > 1) { } else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; // diagnostic; suppress eventually log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}
for (auto &item : sparams.api_keys) {
std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1];
LOG_TEE("Loaded api key for user %s: %s with usercodename %s\n", username.c_str(), apikey.c_str(), usercode.c_str());
} }
// load the model // load the model
@ -2930,27 +2927,19 @@ int main(int argc, char ** argv) {
// Middleware for API key validation // Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, because the file is empty, skip validation // If API key is not set, skip validation
if (sparams.api_keys.empty()) { if (sparams.api_keys.empty()) {
return true; return true;
} }
// Check for API key in the header (TODO: need to add username eventually but ...) // Check for API key in the header
auto auth_header = req.get_header_value("Authorization"); auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer "; std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) { if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size()); std::string received_api_key = auth_header.substr(prefix.size());
LOG("Received API key = %s\n", received_api_key.c_str()); if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
return true; // API key is valid
for (auto& item : sparams.api_keys) {
std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1]; // all three defined in anticipation of later use
if (received_api_key == apikey) {
LOG("Apikey found for user %s with username %s\n", username.c_str(), usercode.c_str());
return true;
}
} }
} }