Improved apikey code

This commit is contained in:
pudepiedj 2024-02-27 12:38:16 +00:00
parent 02702d975d
commit 5854b0b86d
5 changed files with 101 additions and 54 deletions

View file

@ -98,8 +98,8 @@ if __name__ == "__main__":
url = "http://192.168.1.31:8080/completion" url = "http://192.168.1.31:8080/completion"
num_requests = 20 num_requests = 76
q = Queue(maxsize = 64) q = Queue(maxsize = 80)
threads = [] threads = []
bar = make_empty_bar(num_requests) bar = make_empty_bar(num_requests)

View file

@ -1,7 +1,9 @@
john123456 {
susan987654 "john":["john123456","john0001"],
guestabcdef "susan":["susan987654","susan0001"],
fred123123 "guest":["guestabcdef","guest0001"],
george890890 "fred": ["fred123123","fred0001"],
sandra234234 "george":["george890890","george0001"],
tilly567567 "sandra":["sandra234234","sandra0001"],
"tilly":["tilly567567","tilly0001"]
}

View file

@ -96,7 +96,7 @@
// the value here (8u, 16u, 32u, etc) is what governs max threads at 5126 // the value here (8u, 16u, 32u, etc) is what governs max threads at 5126
#ifndef CPPHTTPLIB_THREAD_POOL_COUNT #ifndef CPPHTTPLIB_THREAD_POOL_COUNT
#define CPPHTTPLIB_THREAD_POOL_COUNT \ #define CPPHTTPLIB_THREAD_POOL_COUNT \
((std::max)(64u, std::thread::hardware_concurrency() > 0 \ ((std::max)(128u, std::thread::hardware_concurrency() > 0 \
? std::thread::hardware_concurrency() - 1 \ ? std::thread::hardware_concurrency() - 1 \
: 0)) : 0))
#endif #endif

View file

@ -38,7 +38,7 @@ using json = nlohmann::json;
struct server_params struct server_params
{ {
std::string hostname = "127.0.0.1"; // --host switches to use 0.0.0.0 for public network. std::string hostname = "127.0.0.1"; // --host switches to use 0.0.0.0 for public network.
std::vector<std::string> api_keys; std::map<std::string, std::vector<std::string>> api_keys; // store for improved api_keys database
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
std::string chat_template = ""; std::string chat_template = "";
int32_t port = 8080; int32_t port = 8080;
@ -313,14 +313,14 @@ struct llama_client_slot
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing, num_prompt_tokens_processed,
t_token, n_tokens_second); t_token, n_tokens_second);
/*LOG_INFO(buffer, { LOG_INFO(buffer, {
{"slot_id", id}, {"slot_id", id},
{"task_id", task_id}, {"task_id", task_id},
{"t_prompt_processing", t_prompt_processing}, {"t_prompt_processing", t_prompt_processing},
{"num_prompt_tokens_processed", num_prompt_tokens_processed}, {"num_prompt_tokens_processed", num_prompt_tokens_processed},
{"t_token", t_token}, {"t_token", t_token},
{"n_tokens_second", n_tokens_second}, {"n_tokens_second", n_tokens_second},
});*/ });
t_token = t_token_generation / n_decoded; t_token = t_token_generation / n_decoded;
n_tokens_second = 1e3 / t_token_generation * n_decoded; n_tokens_second = 1e3 / t_token_generation * n_decoded;
@ -328,24 +328,24 @@ struct llama_client_slot
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
t_token_generation, n_decoded, t_token_generation, n_decoded,
t_token, n_tokens_second); t_token, n_tokens_second);
/*LOG_INFO(buffer, { LOG_INFO(buffer, {
{"slot_id", id}, {"slot_id", id},
{"task_id", task_id}, {"task_id", task_id},
{"t_token_generation", t_token_generation}, {"t_token_generation", t_token_generation},
{"n_decoded", n_decoded}, {"n_decoded", n_decoded},
{"t_token", t_token}, {"t_token", t_token},
{"n_tokens_second", n_tokens_second}, {"n_tokens_second", n_tokens_second},
});*/ });
printf("\033[5;0H]"); printf("\033[5;0H]");
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation); sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
/*LOG_INFO(buffer, { LOG_INFO(buffer, {
{"slot_id", id}, {"slot_id", id},
{"task_id", task_id}, {"task_id", task_id},
{"t_prompt_processing", t_prompt_processing}, {"t_prompt_processing", t_prompt_processing},
{"t_token_generation", t_token_generation}, {"t_token_generation", t_token_generation},
{"t_total", t_prompt_processing + t_token_generation}, {"t_total", t_prompt_processing + t_token_generation},
});*/ });
} }
}; };
@ -571,10 +571,10 @@ struct llama_server_context
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict; slot.n_predict = params.n_predict;
/*LOG_INFO("new slot", { LOG_INFO("new slot", {
{"slot_id", slot.id}, {"slot_id", slot.id},
{"n_ctx_slot", slot.n_ctx} {"n_ctx_slot", slot.n_ctx}
});*/ });
const int ga_n = params.grp_attn_n; const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w; const int ga_w = params.grp_attn_w;
@ -585,11 +585,11 @@ struct llama_server_context
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
/*LOG_INFO("slot self-extend", { LOG_INFO("slot self-extend", {
{"slot_id", slot.id}, {"slot_id", slot.id},
{"ga_n", ga_n}, {"ga_n", ga_n},
{"ga_w", ga_w} {"ga_w", ga_w}
});*/ });
} }
slot.ga_i = 0; slot.ga_i = 0;
@ -966,10 +966,10 @@ struct llama_server_context
all_slots_are_idle = false; all_slots_are_idle = false;
/*LOG_INFO("slot is processing task", { LOG_INFO("slot is processing task", {
{"slot_id", slot->id}, {"slot_id", slot->id},
{"task_id", slot->task_id}, {"task_id", slot->task_id},
});*/ });
return true; return true;
} }
@ -1634,11 +1634,11 @@ struct llama_server_context
} }
slots_data.push_back(slot_data); slots_data.push_back(slot_data);
} }
/*LOG_INFO("slot data", { LOG_INFO("slot data", {
{"task_id", task.id}, {"task_id", task.id},
{"n_idle_slots", n_idle_slots}, {"n_idle_slots", n_idle_slots},
{"n_processing_slots", n_processing_slots} {"n_processing_slots", n_processing_slots}
});*/ });
LOG_VERBOSE("slot data", { LOG_VERBOSE("slot data", {
{"task_id", task.id}, {"task_id", task.id},
{"n_idle_slots", n_idle_slots}, {"n_idle_slots", n_idle_slots},
@ -1706,7 +1706,7 @@ struct llama_server_context
{ {
if (system_prompt.empty() && clean_kv_cache) if (system_prompt.empty() && clean_kv_cache)
{ {
/*LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});*/ LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
kv_cache_clear(); kv_cache_clear();
} }
return true; return true;
@ -1731,7 +1731,7 @@ struct llama_server_context
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
const int n_discard = n_left / 2; const int n_discard = n_left / 2;
/*LOG_INFO("slot context shift", { LOG_INFO("slot context shift", {
{"slot_id", slot.id}, {"slot_id", slot.id},
{"task_id", slot.task_id}, {"task_id", slot.task_id},
{"n_keep", n_keep}, {"n_keep", n_keep},
@ -1741,7 +1741,7 @@ struct llama_server_context
{"n_past", slot.n_past}, {"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()}, {"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()} {"n_cache_tokens", slot.cache_tokens.size()}
});*/ });
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
@ -1770,7 +1770,7 @@ struct llama_server_context
slot.command = NONE; slot.command = NONE;
slot.t_last_used = ggml_time_us(); slot.t_last_used = ggml_time_us();
/*LOG_INFO("slot released", { LOG_INFO("slot released", {
{"slot_id", slot.id}, {"slot_id", slot.id},
{"task_id", slot.task_id}, {"task_id", slot.task_id},
{"n_ctx", n_ctx}, {"n_ctx", n_ctx},
@ -1778,7 +1778,7 @@ struct llama_server_context
{"n_system_tokens", system_tokens.size()}, {"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}, {"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated} {"truncated", slot.truncated}
});*/ });
queue_tasks.notify_slot_changed(); queue_tasks.notify_slot_changed();
continue; continue;
@ -1934,12 +1934,12 @@ struct llama_server_context
slot.ga_i = ga_i; slot.ga_i = ga_i;
} }
/*LOG_INFO("slot progression", { LOG_INFO("slot progression", {
{ "slot_id", slot.id }, { "slot_id", slot.id },
{ "task_id", slot.task_id }, { "task_id", slot.task_id },
{ "n_past", slot.n_past }, { "n_past", slot.n_past },
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed } { "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
});*/ });
} }
slot.cache_tokens = prompt_tokens; slot.cache_tokens = prompt_tokens;
@ -1959,11 +1959,11 @@ struct llama_server_context
} }
int p0 = (int) system_tokens.size() + slot.n_past; int p0 = (int) system_tokens.size() + slot.n_past;
/*LOG_INFO("kv cache rm [p0, end)", { LOG_INFO("kv cache rm [p0, end)", {
{ "slot_id", slot.id }, { "slot_id", slot.id },
{ "task_id", slot.task_id }, { "task_id", slot.task_id },
{ "p0", p0 } { "p0", p0 }
});*/ });
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1); llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
LOG_VERBOSE("prompt ingested", { LOG_VERBOSE("prompt ingested", {
@ -2258,6 +2258,30 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n"); printf("\n");
} }
static std::map<std::string, std::vector<std::string>> get_userdata(std::string file) {
// Map to store user apikey records {username: {apikey, usercodename}}
using Record = std::map<std::string, std::vector<std::string>>;
Record records;
std::ifstream infile(file);
nlohmann::json data;
infile >> data;
for(auto it = data.begin(); it != data.end(); ++it) {
//nlohmann::json obj = it.value();
std::string username = it.key();
std::vector<std::string> info = it.value();
records[username] = info;
}
return records;
}
static void server_params_parse(int argc, char **argv, server_params &sparams, static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context &llama) gpt_params &params, llama_server_context &llama)
{ {
@ -2296,6 +2320,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
sparams.public_path = argv[i]; sparams.public_path = argv[i];
} }
/*
else if (arg == "--api-key") else if (arg == "--api-key")
{ {
if (++i >= argc) if (++i >= argc)
@ -2303,8 +2328,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.api_keys.emplace_back(argv[i]); std::string newuser = argv[i][0];
std::string newuserapi = argv[i][1];
std::string newusercode = argv[i][2]
sparams.api_keys.emplace_back({newuser: {newuserapi, newusercode}});
} }
*/
else if (arg == "--api-key-file") else if (arg == "--api-key-file")
{ {
if (++i >= argc) if (++i >= argc)
@ -2318,12 +2347,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
std::string key; sparams.api_keys = get_userdata(argv[i]);
while (std::getline(key_file, key)) {
if (key.size() > 0) {
sparams.api_keys.push_back(key);
}
}
key_file.close(); key_file.close();
} }
else if (arg == "--timeout" || arg == "-to") else if (arg == "--timeout" || arg == "-to")
@ -2859,14 +2884,14 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo
return; return;
} }
/*LOG_INFO("request", { LOG_INFO("request", {
{"remote_addr", req.remote_addr}, {"remote_addr", req.remote_addr},
{"remote_port", req.remote_port}, {"remote_port", req.remote_port},
{"status", res.status}, {"status", res.status},
{"method", req.method}, {"method", req.method},
{"path", req.path}, {"path", req.path},
{"params", req.params}, {"params", req.params},
});*/ });
LOG_VERBOSE("request", { LOG_VERBOSE("request", {
{"request", req.body}, {"request", req.body},
@ -2923,7 +2948,7 @@ int main(int argc, char **argv)
llama_numa_init(params.numa); llama_numa_init(params.numa);
ggml_time_init(); ggml_time_init();
/*LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
{"commit", LLAMA_COMMIT}}); {"commit", LLAMA_COMMIT}});
LOG_INFO("system info", { LOG_INFO("system info", {
@ -2931,7 +2956,7 @@ int main(int argc, char **argv)
{"n_threads_batch", params.n_threads_batch}, {"n_threads_batch", params.n_threads_batch},
{"total_threads", std::thread::hardware_concurrency()}, {"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()}, {"system_info", llama_print_system_info()},
});*/ });
httplib::Server svr; httplib::Server svr;
@ -3154,12 +3179,16 @@ int main(int argc, char **argv)
log_data["port"] = std::to_string(sparams.port); log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) { // what happens if the size is zero? if (sparams.api_keys.size() == 1) { // what happens if the size is zero?
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); log_data["api_key"] = "api_key: ****" + sparams.api_keys[0][0].substr(sparams.api_keys[0][0].length() - 4);
} else if (sparams.api_keys.size() > 1) { } else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
} }
for (int i=0; i<int(sparams.api_keys.size()); i++) { for (auto &item : sparams.api_keys) {
LOG_TEE("Loaded api key #%d: %s\n", i, sparams.api_keys[i].c_str()); std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1];
LOG_TEE("Loaded api key for user %s: %s with usercodename %s\n", username.c_str(), apikey.c_str(), usercode.c_str());
} }
// load the model // load the model
@ -3191,6 +3220,7 @@ int main(int argc, char **argv)
if (auth_header.substr(0, prefix.size()) == prefix) { if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size()); std::string received_api_key = auth_header.substr(prefix.size());
LOG("Received API key = %s\n", received_api_key.c_str()); LOG("Received API key = %s\n", received_api_key.c_str());
/*
for (int i = 0; i < int(sparams.api_keys.size()); i++) { for (int i = 0; i < int(sparams.api_keys.size()); i++) {
// for some reason the file apikeys are one character longer than those passed from Bearer so we shorten them // for some reason the file apikeys are one character longer than those passed from Bearer so we shorten them
std::string uncut_api = sparams.api_keys[i]; // store original apikey std::string uncut_api = sparams.api_keys[i]; // store original apikey
@ -3201,11 +3231,20 @@ int main(int argc, char **argv)
LOG("%s = %s Found matching api key.\n", received_api_key.c_str(), cut_api.c_str()); LOG("%s = %s Found matching api key.\n", received_api_key.c_str(), cut_api.c_str());
return true; return true;
} }
*/
for (auto& item : sparams.api_keys) {
std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1]; // all three defined in anticipation of later use
if (received_api_key == apikey) {
LOG("Apikey found for user %s\n", username.c_str());
return true;
}
}
} }
//if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { //if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
// return true; // API key is valid // return true; // API key is valid
//} //}
}
// API key is invalid or not provided // API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");

View file

@ -143,6 +143,9 @@ static inline void server_log(const char *level, const char *function, int line,
{"timestamp", time(nullptr)}, {"timestamp", time(nullptr)},
}; };
// freopen("/dev/null", "w", stdout);
freopen("/dev/null", "w", stderr);
if (server_log_json) { if (server_log_json) {
log.merge_patch( log.merge_patch(
{ {
@ -155,10 +158,10 @@ static inline void server_log(const char *level, const char *function, int line,
log.merge_patch(extra); log.merge_patch(extra);
} }
std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush; std::cerr << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
} else { } else {
char buf[1024]; char buf[1024];
snprintf(buf, 1024, "\033[72;0H%4s [%24s] %s", level, function, message); snprintf(buf, 1024, "\033[85;0H%4s [%24s] %s", level, function, message);
if (!extra.empty()) { if (!extra.empty()) {
log.merge_patch(extra); log.merge_patch(extra);
@ -168,13 +171,16 @@ static inline void server_log(const char *level, const char *function, int line,
for (const auto& el : log.items()) for (const auto& el : log.items())
{ {
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
snprintf(buf, 1024, "\033[72;0H %s=%s", el.key().c_str(), value.c_str()); snprintf(buf, 1024, "\033[85;0H %s=%s", el.key().c_str(), value.c_str());
ss << buf; ss << buf;
} }
const std::string str = ss.str(); const std::string str = ss.str();
printf("\033[72;0H%.*s\n", (int)str.size(), str.data()); printf("\033[85;0H%.*s\n", (int)str.size(), str.data());
fflush(stdout); fflush(stderr);
// freopen("/dev/tty", "a", stdout);
// freopen("/dev/tty", "a", stderr);
} }
} }