Improved apikey code

This commit is contained in:
pudepiedj 2024-02-27 12:38:16 +00:00
parent 02702d975d
commit 5854b0b86d
5 changed files with 101 additions and 54 deletions

View file

@ -98,8 +98,8 @@ if __name__ == "__main__":
url = "http://192.168.1.31:8080/completion"
num_requests = 20
q = Queue(maxsize = 64)
num_requests = 76
q = Queue(maxsize = 80)
threads = []
bar = make_empty_bar(num_requests)

View file

@ -1,7 +1,9 @@
john123456
susan987654
guestabcdef
fred123123
george890890
sandra234234
tilly567567
{
"john":["john123456","john0001"],
"susan":["susan987654","susan0001"],
"guest":["guestabcdef","guest0001"],
"fred": ["fred123123","fred0001"],
"george":["george890890","george0001"],
"sandra":["sandra234234","sandra0001"],
"tilly":["tilly567567","tilly0001"]
}

View file

@ -96,7 +96,7 @@
// the value here (8u, 16u, 32u, etc) is what governs max threads at 5126
#ifndef CPPHTTPLIB_THREAD_POOL_COUNT
#define CPPHTTPLIB_THREAD_POOL_COUNT \
((std::max)(64u, std::thread::hardware_concurrency() > 0 \
((std::max)(128u, std::thread::hardware_concurrency() > 0 \
? std::thread::hardware_concurrency() - 1 \
: 0))
#endif

View file

@ -38,7 +38,7 @@ using json = nlohmann::json;
struct server_params
{
std::string hostname = "127.0.0.1"; // --host switches to use 0.0.0.0 for public network.
std::vector<std::string> api_keys;
std::map<std::string, std::vector<std::string>> api_keys; // store for improved api_keys database
std::string public_path = "examples/server/public";
std::string chat_template = "";
int32_t port = 8080;
@ -313,14 +313,14 @@ struct llama_client_slot
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
t_prompt_processing, num_prompt_tokens_processed,
t_token, n_tokens_second);
/*LOG_INFO(buffer, {
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_prompt_processing", t_prompt_processing},
{"num_prompt_tokens_processed", num_prompt_tokens_processed},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});*/
});
t_token = t_token_generation / n_decoded;
n_tokens_second = 1e3 / t_token_generation * n_decoded;
@ -328,24 +328,24 @@ struct llama_client_slot
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
t_token_generation, n_decoded,
t_token, n_tokens_second);
/*LOG_INFO(buffer, {
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_token_generation", t_token_generation},
{"n_decoded", n_decoded},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});*/
});
printf("\033[5;0H]");
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
/*LOG_INFO(buffer, {
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_prompt_processing", t_prompt_processing},
{"t_token_generation", t_token_generation},
{"t_total", t_prompt_processing + t_token_generation},
});*/
});
}
};
@ -571,10 +571,10 @@ struct llama_server_context
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
/*LOG_INFO("new slot", {
LOG_INFO("new slot", {
{"slot_id", slot.id},
{"n_ctx_slot", slot.n_ctx}
});*/
});
const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;
@ -585,11 +585,11 @@ struct llama_server_context
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
/*LOG_INFO("slot self-extend", {
LOG_INFO("slot self-extend", {
{"slot_id", slot.id},
{"ga_n", ga_n},
{"ga_w", ga_w}
});*/
});
}
slot.ga_i = 0;
@ -966,10 +966,10 @@ struct llama_server_context
all_slots_are_idle = false;
/*LOG_INFO("slot is processing task", {
LOG_INFO("slot is processing task", {
{"slot_id", slot->id},
{"task_id", slot->task_id},
});*/
});
return true;
}
@ -1634,11 +1634,11 @@ struct llama_server_context
}
slots_data.push_back(slot_data);
}
/*LOG_INFO("slot data", {
LOG_INFO("slot data", {
{"task_id", task.id},
{"n_idle_slots", n_idle_slots},
{"n_processing_slots", n_processing_slots}
});*/
});
LOG_VERBOSE("slot data", {
{"task_id", task.id},
{"n_idle_slots", n_idle_slots},
@ -1706,7 +1706,7 @@ struct llama_server_context
{
if (system_prompt.empty() && clean_kv_cache)
{
/*LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});*/
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
kv_cache_clear();
}
return true;
@ -1731,7 +1731,7 @@ struct llama_server_context
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
const int n_discard = n_left / 2;
/*LOG_INFO("slot context shift", {
LOG_INFO("slot context shift", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_keep", n_keep},
@ -1741,7 +1741,7 @@ struct llama_server_context
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}
});*/
});
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
@ -1770,7 +1770,7 @@ struct llama_server_context
slot.command = NONE;
slot.t_last_used = ggml_time_us();
/*LOG_INFO("slot released", {
LOG_INFO("slot released", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_ctx", n_ctx},
@ -1778,7 +1778,7 @@ struct llama_server_context
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});*/
});
queue_tasks.notify_slot_changed();
continue;
@ -1934,12 +1934,12 @@ struct llama_server_context
slot.ga_i = ga_i;
}
/*LOG_INFO("slot progression", {
LOG_INFO("slot progression", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
});*/
});
}
slot.cache_tokens = prompt_tokens;
@ -1959,11 +1959,11 @@ struct llama_server_context
}
int p0 = (int) system_tokens.size() + slot.n_past;
/*LOG_INFO("kv cache rm [p0, end)", {
LOG_INFO("kv cache rm [p0, end)", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "p0", p0 }
});*/
});
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
LOG_VERBOSE("prompt ingested", {
@ -2258,6 +2258,30 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n");
}
static std::map<std::string, std::vector<std::string>> get_userdata(std::string file) {
// Map to store user apikey records {username: {apikey, usercodename}}
using Record = std::map<std::string, std::vector<std::string>>;
Record records;
std::ifstream infile(file);
nlohmann::json data;
infile >> data;
for(auto it = data.begin(); it != data.end(); ++it) {
//nlohmann::json obj = it.value();
std::string username = it.key();
std::vector<std::string> info = it.value();
records[username] = info;
}
return records;
}
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context &llama)
{
@ -2296,6 +2320,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
sparams.public_path = argv[i];
}
/*
else if (arg == "--api-key")
{
if (++i >= argc)
@ -2303,8 +2328,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true;
break;
}
sparams.api_keys.emplace_back(argv[i]);
std::string newuser = argv[i][0];
std::string newuserapi = argv[i][1];
std::string newusercode = argv[i][2]
sparams.api_keys.emplace_back({newuser: {newuserapi, newusercode}});
}
*/
else if (arg == "--api-key-file")
{
if (++i >= argc)
@ -2318,12 +2347,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true;
break;
}
std::string key;
while (std::getline(key_file, key)) {
if (key.size() > 0) {
sparams.api_keys.push_back(key);
}
}
sparams.api_keys = get_userdata(argv[i]);
key_file.close();
}
else if (arg == "--timeout" || arg == "-to")
@ -2859,14 +2884,14 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo
return;
}
/*LOG_INFO("request", {
LOG_INFO("request", {
{"remote_addr", req.remote_addr},
{"remote_port", req.remote_port},
{"status", res.status},
{"method", req.method},
{"path", req.path},
{"params", req.params},
});*/
});
LOG_VERBOSE("request", {
{"request", req.body},
@ -2923,7 +2948,7 @@ int main(int argc, char **argv)
llama_numa_init(params.numa);
ggml_time_init();
/*LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
{"commit", LLAMA_COMMIT}});
LOG_INFO("system info", {
@ -2931,7 +2956,7 @@ int main(int argc, char **argv)
{"n_threads_batch", params.n_threads_batch},
{"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()},
});*/
});
httplib::Server svr;
@ -3154,12 +3179,16 @@ int main(int argc, char **argv)
log_data["port"] = std::to_string(sparams.port);
if (sparams.api_keys.size() == 1) { // what happens if the size is zero?
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0][0].substr(sparams.api_keys[0][0].length() - 4);
} else if (sparams.api_keys.size() > 1) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}
for (int i=0; i<int(sparams.api_keys.size()); i++) {
LOG_TEE("Loaded api key #%d: %s\n", i, sparams.api_keys[i].c_str());
for (auto &item : sparams.api_keys) {
std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1];
LOG_TEE("Loaded api key for user %s: %s with usercodename %s\n", username.c_str(), apikey.c_str(), usercode.c_str());
}
// load the model
@ -3191,6 +3220,7 @@ int main(int argc, char **argv)
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
LOG("Received API key = %s\n", received_api_key.c_str());
/*
for (int i = 0; i < int(sparams.api_keys.size()); i++) {
// for some reason the file apikeys are one character longer than those passed from Bearer so we shorten them
std::string uncut_api = sparams.api_keys[i]; // store original apikey
@ -3201,11 +3231,20 @@ int main(int argc, char **argv)
LOG("%s = %s Found matching api key.\n", received_api_key.c_str(), cut_api.c_str());
return true;
}
*/
for (auto& item : sparams.api_keys) {
std::string username = item.first;
std::string apikey = item.second[0];
std::string usercode = item.second[1]; // all three defined in anticipation of later use
if (received_api_key == apikey) {
LOG("Apikey found for user %s\n", username.c_str());
return true;
}
}
}
//if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
// return true; // API key is valid
//}
}
// API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");

View file

@ -143,6 +143,9 @@ static inline void server_log(const char *level, const char *function, int line,
{"timestamp", time(nullptr)},
};
// freopen("/dev/null", "w", stdout);
freopen("/dev/null", "w", stderr);
if (server_log_json) {
log.merge_patch(
{
@ -155,10 +158,10 @@ static inline void server_log(const char *level, const char *function, int line,
log.merge_patch(extra);
}
std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
std::cerr << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
} else {
char buf[1024];
snprintf(buf, 1024, "\033[72;0H%4s [%24s] %s", level, function, message);
snprintf(buf, 1024, "\033[85;0H%4s [%24s] %s", level, function, message);
if (!extra.empty()) {
log.merge_patch(extra);
@ -168,13 +171,16 @@ static inline void server_log(const char *level, const char *function, int line,
for (const auto& el : log.items())
{
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
snprintf(buf, 1024, "\033[72;0H %s=%s", el.key().c_str(), value.c_str());
snprintf(buf, 1024, "\033[85;0H %s=%s", el.key().c_str(), value.c_str());
ss << buf;
}
const std::string str = ss.str();
printf("\033[72;0H%.*s\n", (int)str.size(), str.data());
fflush(stdout);
printf("\033[85;0H%.*s\n", (int)str.size(), str.data());
fflush(stderr);
// freopen("/dev/tty", "a", stdout);
// freopen("/dev/tty", "a", stderr);
}
}