Improved apikey code
This commit is contained in:
parent
02702d975d
commit
5854b0b86d
5 changed files with 101 additions and 54 deletions
|
@ -98,8 +98,8 @@ if __name__ == "__main__":
|
|||
|
||||
url = "http://192.168.1.31:8080/completion"
|
||||
|
||||
num_requests = 20
|
||||
q = Queue(maxsize = 64)
|
||||
num_requests = 76
|
||||
q = Queue(maxsize = 80)
|
||||
threads = []
|
||||
|
||||
bar = make_empty_bar(num_requests)
|
||||
|
|
16
apikeys.txt
16
apikeys.txt
|
@ -1,7 +1,9 @@
|
|||
john123456
|
||||
susan987654
|
||||
guestabcdef
|
||||
fred123123
|
||||
george890890
|
||||
sandra234234
|
||||
tilly567567
|
||||
{
|
||||
"john":["john123456","john0001"],
|
||||
"susan":["susan987654","susan0001"],
|
||||
"guest":["guestabcdef","guest0001"],
|
||||
"fred": ["fred123123","fred0001"],
|
||||
"george":["george890890","george0001"],
|
||||
"sandra":["sandra234234","sandra0001"],
|
||||
"tilly":["tilly567567","tilly0001"]
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@
|
|||
// the value here (8u, 16u, 32u, etc) is what governs max threads at 5126
|
||||
#ifndef CPPHTTPLIB_THREAD_POOL_COUNT
|
||||
#define CPPHTTPLIB_THREAD_POOL_COUNT \
|
||||
((std::max)(64u, std::thread::hardware_concurrency() > 0 \
|
||||
((std::max)(128u, std::thread::hardware_concurrency() > 0 \
|
||||
? std::thread::hardware_concurrency() - 1 \
|
||||
: 0))
|
||||
#endif
|
||||
|
|
|
@ -38,7 +38,7 @@ using json = nlohmann::json;
|
|||
struct server_params
|
||||
{
|
||||
std::string hostname = "127.0.0.1"; // --host switches to use 0.0.0.0 for public network.
|
||||
std::vector<std::string> api_keys;
|
||||
std::map<std::string, std::vector<std::string>> api_keys; // store for improved api_keys database
|
||||
std::string public_path = "examples/server/public";
|
||||
std::string chat_template = "";
|
||||
int32_t port = 8080;
|
||||
|
@ -313,14 +313,14 @@ struct llama_client_slot
|
|||
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
||||
t_prompt_processing, num_prompt_tokens_processed,
|
||||
t_token, n_tokens_second);
|
||||
/*LOG_INFO(buffer, {
|
||||
LOG_INFO(buffer, {
|
||||
{"slot_id", id},
|
||||
{"task_id", task_id},
|
||||
{"t_prompt_processing", t_prompt_processing},
|
||||
{"num_prompt_tokens_processed", num_prompt_tokens_processed},
|
||||
{"t_token", t_token},
|
||||
{"n_tokens_second", n_tokens_second},
|
||||
});*/
|
||||
});
|
||||
|
||||
t_token = t_token_generation / n_decoded;
|
||||
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
||||
|
@ -328,24 +328,24 @@ struct llama_client_slot
|
|||
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
||||
t_token_generation, n_decoded,
|
||||
t_token, n_tokens_second);
|
||||
/*LOG_INFO(buffer, {
|
||||
LOG_INFO(buffer, {
|
||||
{"slot_id", id},
|
||||
{"task_id", task_id},
|
||||
{"t_token_generation", t_token_generation},
|
||||
{"n_decoded", n_decoded},
|
||||
{"t_token", t_token},
|
||||
{"n_tokens_second", n_tokens_second},
|
||||
});*/
|
||||
});
|
||||
|
||||
printf("\033[5;0H]");
|
||||
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
||||
/*LOG_INFO(buffer, {
|
||||
LOG_INFO(buffer, {
|
||||
{"slot_id", id},
|
||||
{"task_id", task_id},
|
||||
{"t_prompt_processing", t_prompt_processing},
|
||||
{"t_token_generation", t_token_generation},
|
||||
{"t_total", t_prompt_processing + t_token_generation},
|
||||
});*/
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -571,10 +571,10 @@ struct llama_server_context
|
|||
slot.n_ctx = n_ctx_slot;
|
||||
slot.n_predict = params.n_predict;
|
||||
|
||||
/*LOG_INFO("new slot", {
|
||||
LOG_INFO("new slot", {
|
||||
{"slot_id", slot.id},
|
||||
{"n_ctx_slot", slot.n_ctx}
|
||||
});*/
|
||||
});
|
||||
|
||||
const int ga_n = params.grp_attn_n;
|
||||
const int ga_w = params.grp_attn_w;
|
||||
|
@ -585,11 +585,11 @@ struct llama_server_context
|
|||
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
||||
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
||||
|
||||
/*LOG_INFO("slot self-extend", {
|
||||
LOG_INFO("slot self-extend", {
|
||||
{"slot_id", slot.id},
|
||||
{"ga_n", ga_n},
|
||||
{"ga_w", ga_w}
|
||||
});*/
|
||||
});
|
||||
}
|
||||
|
||||
slot.ga_i = 0;
|
||||
|
@ -966,10 +966,10 @@ struct llama_server_context
|
|||
|
||||
all_slots_are_idle = false;
|
||||
|
||||
/*LOG_INFO("slot is processing task", {
|
||||
LOG_INFO("slot is processing task", {
|
||||
{"slot_id", slot->id},
|
||||
{"task_id", slot->task_id},
|
||||
});*/
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -1634,11 +1634,11 @@ struct llama_server_context
|
|||
}
|
||||
slots_data.push_back(slot_data);
|
||||
}
|
||||
/*LOG_INFO("slot data", {
|
||||
LOG_INFO("slot data", {
|
||||
{"task_id", task.id},
|
||||
{"n_idle_slots", n_idle_slots},
|
||||
{"n_processing_slots", n_processing_slots}
|
||||
});*/
|
||||
});
|
||||
LOG_VERBOSE("slot data", {
|
||||
{"task_id", task.id},
|
||||
{"n_idle_slots", n_idle_slots},
|
||||
|
@ -1706,7 +1706,7 @@ struct llama_server_context
|
|||
{
|
||||
if (system_prompt.empty() && clean_kv_cache)
|
||||
{
|
||||
/*LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});*/
|
||||
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
|
||||
kv_cache_clear();
|
||||
}
|
||||
return true;
|
||||
|
@ -1731,7 +1731,7 @@ struct llama_server_context
|
|||
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
||||
const int n_discard = n_left / 2;
|
||||
|
||||
/*LOG_INFO("slot context shift", {
|
||||
LOG_INFO("slot context shift", {
|
||||
{"slot_id", slot.id},
|
||||
{"task_id", slot.task_id},
|
||||
{"n_keep", n_keep},
|
||||
|
@ -1741,7 +1741,7 @@ struct llama_server_context
|
|||
{"n_past", slot.n_past},
|
||||
{"n_system_tokens", system_tokens.size()},
|
||||
{"n_cache_tokens", slot.cache_tokens.size()}
|
||||
});*/
|
||||
});
|
||||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
|
||||
|
@ -1770,7 +1770,7 @@ struct llama_server_context
|
|||
slot.command = NONE;
|
||||
slot.t_last_used = ggml_time_us();
|
||||
|
||||
/*LOG_INFO("slot released", {
|
||||
LOG_INFO("slot released", {
|
||||
{"slot_id", slot.id},
|
||||
{"task_id", slot.task_id},
|
||||
{"n_ctx", n_ctx},
|
||||
|
@ -1778,7 +1778,7 @@ struct llama_server_context
|
|||
{"n_system_tokens", system_tokens.size()},
|
||||
{"n_cache_tokens", slot.cache_tokens.size()},
|
||||
{"truncated", slot.truncated}
|
||||
});*/
|
||||
});
|
||||
queue_tasks.notify_slot_changed();
|
||||
|
||||
continue;
|
||||
|
@ -1934,12 +1934,12 @@ struct llama_server_context
|
|||
slot.ga_i = ga_i;
|
||||
}
|
||||
|
||||
/*LOG_INFO("slot progression", {
|
||||
LOG_INFO("slot progression", {
|
||||
{ "slot_id", slot.id },
|
||||
{ "task_id", slot.task_id },
|
||||
{ "n_past", slot.n_past },
|
||||
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
|
||||
});*/
|
||||
});
|
||||
}
|
||||
|
||||
slot.cache_tokens = prompt_tokens;
|
||||
|
@ -1959,11 +1959,11 @@ struct llama_server_context
|
|||
}
|
||||
|
||||
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||
/*LOG_INFO("kv cache rm [p0, end)", {
|
||||
LOG_INFO("kv cache rm [p0, end)", {
|
||||
{ "slot_id", slot.id },
|
||||
{ "task_id", slot.task_id },
|
||||
{ "p0", p0 }
|
||||
});*/
|
||||
});
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
|
@ -2258,6 +2258,30 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf("\n");
|
||||
}
|
||||
|
||||
static std::map<std::string, std::vector<std::string>> get_userdata(std::string file) {
|
||||
|
||||
// Map to store user apikey records {username: {apikey, usercodename}}
|
||||
using Record = std::map<std::string, std::vector<std::string>>;
|
||||
|
||||
Record records;
|
||||
|
||||
std::ifstream infile(file);
|
||||
nlohmann::json data;
|
||||
infile >> data;
|
||||
|
||||
for(auto it = data.begin(); it != data.end(); ++it) {
|
||||
|
||||
//nlohmann::json obj = it.value();
|
||||
|
||||
std::string username = it.key();
|
||||
std::vector<std::string> info = it.value();
|
||||
|
||||
records[username] = info;
|
||||
}
|
||||
|
||||
return records;
|
||||
}
|
||||
|
||||
static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
gpt_params ¶ms, llama_server_context &llama)
|
||||
{
|
||||
|
@ -2296,6 +2320,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
sparams.public_path = argv[i];
|
||||
}
|
||||
/*
|
||||
else if (arg == "--api-key")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2303,8 +2328,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.api_keys.emplace_back(argv[i]);
|
||||
std::string newuser = argv[i][0];
|
||||
std::string newuserapi = argv[i][1];
|
||||
std::string newusercode = argv[i][2]
|
||||
sparams.api_keys.emplace_back({newuser: {newuserapi, newusercode}});
|
||||
}
|
||||
*/
|
||||
else if (arg == "--api-key-file")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2318,12 +2347,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string key;
|
||||
while (std::getline(key_file, key)) {
|
||||
if (key.size() > 0) {
|
||||
sparams.api_keys.push_back(key);
|
||||
}
|
||||
}
|
||||
sparams.api_keys = get_userdata(argv[i]);
|
||||
|
||||
key_file.close();
|
||||
}
|
||||
else if (arg == "--timeout" || arg == "-to")
|
||||
|
@ -2859,14 +2884,14 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo
|
|||
return;
|
||||
}
|
||||
|
||||
/*LOG_INFO("request", {
|
||||
LOG_INFO("request", {
|
||||
{"remote_addr", req.remote_addr},
|
||||
{"remote_port", req.remote_port},
|
||||
{"status", res.status},
|
||||
{"method", req.method},
|
||||
{"path", req.path},
|
||||
{"params", req.params},
|
||||
});*/
|
||||
});
|
||||
|
||||
LOG_VERBOSE("request", {
|
||||
{"request", req.body},
|
||||
|
@ -2923,7 +2948,7 @@ int main(int argc, char **argv)
|
|||
llama_numa_init(params.numa);
|
||||
ggml_time_init();
|
||||
|
||||
/*LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
|
||||
LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
|
||||
{"commit", LLAMA_COMMIT}});
|
||||
|
||||
LOG_INFO("system info", {
|
||||
|
@ -2931,7 +2956,7 @@ int main(int argc, char **argv)
|
|||
{"n_threads_batch", params.n_threads_batch},
|
||||
{"total_threads", std::thread::hardware_concurrency()},
|
||||
{"system_info", llama_print_system_info()},
|
||||
});*/
|
||||
});
|
||||
|
||||
httplib::Server svr;
|
||||
|
||||
|
@ -3154,12 +3179,16 @@ int main(int argc, char **argv)
|
|||
log_data["port"] = std::to_string(sparams.port);
|
||||
|
||||
if (sparams.api_keys.size() == 1) { // what happens if the size is zero?
|
||||
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
|
||||
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0][0].substr(sparams.api_keys[0][0].length() - 4);
|
||||
} else if (sparams.api_keys.size() > 1) {
|
||||
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
|
||||
}
|
||||
for (int i=0; i<int(sparams.api_keys.size()); i++) {
|
||||
LOG_TEE("Loaded api key #%d: %s\n", i, sparams.api_keys[i].c_str());
|
||||
for (auto &item : sparams.api_keys) {
|
||||
std::string username = item.first;
|
||||
std::string apikey = item.second[0];
|
||||
std::string usercode = item.second[1];
|
||||
|
||||
LOG_TEE("Loaded api key for user %s: %s with usercodename %s\n", username.c_str(), apikey.c_str(), usercode.c_str());
|
||||
}
|
||||
|
||||
// load the model
|
||||
|
@ -3191,6 +3220,7 @@ int main(int argc, char **argv)
|
|||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||
std::string received_api_key = auth_header.substr(prefix.size());
|
||||
LOG("Received API key = %s\n", received_api_key.c_str());
|
||||
/*
|
||||
for (int i = 0; i < int(sparams.api_keys.size()); i++) {
|
||||
// for some reason the file apikeys are one character longer than those passed from Bearer so we shorten them
|
||||
std::string uncut_api = sparams.api_keys[i]; // store original apikey
|
||||
|
@ -3201,11 +3231,20 @@ int main(int argc, char **argv)
|
|||
LOG("%s = %s Found matching api key.\n", received_api_key.c_str(), cut_api.c_str());
|
||||
return true;
|
||||
}
|
||||
*/
|
||||
for (auto& item : sparams.api_keys) {
|
||||
std::string username = item.first;
|
||||
std::string apikey = item.second[0];
|
||||
std::string usercode = item.second[1]; // all three defined in anticipation of later use
|
||||
if (received_api_key == apikey) {
|
||||
LOG("Apikey found for user %s\n", username.c_str());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
//if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
|
||||
// return true; // API key is valid
|
||||
//}
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
|
||||
|
|
|
@ -143,6 +143,9 @@ static inline void server_log(const char *level, const char *function, int line,
|
|||
{"timestamp", time(nullptr)},
|
||||
};
|
||||
|
||||
// freopen("/dev/null", "w", stdout);
|
||||
freopen("/dev/null", "w", stderr);
|
||||
|
||||
if (server_log_json) {
|
||||
log.merge_patch(
|
||||
{
|
||||
|
@ -155,10 +158,10 @@ static inline void server_log(const char *level, const char *function, int line,
|
|||
log.merge_patch(extra);
|
||||
}
|
||||
|
||||
std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
|
||||
std::cerr << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush;
|
||||
} else {
|
||||
char buf[1024];
|
||||
snprintf(buf, 1024, "\033[72;0H%4s [%24s] %s", level, function, message);
|
||||
snprintf(buf, 1024, "\033[85;0H%4s [%24s] %s", level, function, message);
|
||||
|
||||
if (!extra.empty()) {
|
||||
log.merge_patch(extra);
|
||||
|
@ -168,13 +171,16 @@ static inline void server_log(const char *level, const char *function, int line,
|
|||
for (const auto& el : log.items())
|
||||
{
|
||||
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
snprintf(buf, 1024, "\033[72;0H %s=%s", el.key().c_str(), value.c_str());
|
||||
snprintf(buf, 1024, "\033[85;0H %s=%s", el.key().c_str(), value.c_str());
|
||||
ss << buf;
|
||||
}
|
||||
|
||||
const std::string str = ss.str();
|
||||
printf("\033[72;0H%.*s\n", (int)str.size(), str.data());
|
||||
fflush(stdout);
|
||||
printf("\033[85;0H%.*s\n", (int)str.size(), str.data());
|
||||
fflush(stderr);
|
||||
|
||||
// freopen("/dev/tty", "a", stdout);
|
||||
// freopen("/dev/tty", "a", stderr);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue