Add --slot-save-path arg to enable slot save/restore and restrict the save location

This commit is contained in:
Jan Boon 2024-03-28 00:05:56 +08:00
parent 02a184065a
commit b8e8facb0e

View file

@ -131,6 +131,7 @@ struct server_params {
bool slots_endpoint = true; bool slots_endpoint = true;
bool metrics_endpoint = false; bool metrics_endpoint = false;
std::string slot_save_path;
}; };
struct server_slot { struct server_slot {
@ -1628,6 +1629,7 @@ struct server_context {
const int64_t t_start = ggml_time_us(); const int64_t t_start = ggml_time_us();
std::string filename = task.data["filename"]; std::string filename = task.data["filename"];
std::string filepath = task.data["filepath"];
size_t state_size = llama_get_seq_size(ctx, slot->id + 1); size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
@ -1645,7 +1647,7 @@ struct server_context {
} }
GGML_ASSERT(nwrite <= state_data.size()); GGML_ASSERT(nwrite <= state_data.size());
std::ofstream outfile(filename, std::ios::binary); std::ofstream outfile(filepath, std::ios::binary);
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite); outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
outfile.close(); outfile.close();
@ -1678,8 +1680,9 @@ struct server_context {
const int64_t t_start = ggml_time_us(); const int64_t t_start = ggml_time_us();
std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params? std::string filename = task.data["filename"];
std::ifstream infile(filename, std::ios::binary); std::string filepath = task.data["filepath"];
std::ifstream infile(filepath, std::ios::binary);
if (!infile.is_open()) { if (!infile.is_open()) {
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
break; break;
@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --log-disable disables logging to a file.\n"); printf(" --log-disable disables logging to a file.\n");
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
printf("\n"); printf("\n");
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" --override-kv KEY=TYPE:VALUE\n");
@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
sparams.slots_endpoint = false; sparams.slots_endpoint = false;
} else if (arg == "--metrics") { } else if (arg == "--metrics") {
sparams.metrics_endpoint = true; sparams.metrics_endpoint = true;
} else if (arg == "--slot-save-path") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.slot_save_path = argv[i];
// if doesn't end with DIRECTORY_SEPARATOR, add it
if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
sparams.slot_save_path += DIRECTORY_SEPARATOR;
}
} else if (arg == "--chat-template") { } else if (arg == "--chat-template") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json request_data = json::parse(req.body); json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"]; int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"]; std::string filename = request_data["filename"];
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
res_error(res, "Invalid filename");
return;
}
std::string filepath = sparams.slot_save_path + filename;
server_task task; server_task task;
task.type = SERVER_TASK_TYPE_SLOT_SAVE; task.type = SERVER_TASK_TYPE_SLOT_SAVE;
task.data = { task.data = {
{ "id_slot", id_slot }, { "id_slot", id_slot },
{ "filename", filename }, { "filename", filename },
{ "filepath", filepath }
}; };
const int id_task = ctx_server.queue_tasks.post(task); const int id_task = ctx_server.queue_tasks.post(task);
@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) {
} }
}; };
const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json request_data = json::parse(req.body); json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"]; int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"]; std::string filename = request_data["filename"];
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
res_error(res, "Invalid filename");
return;
}
std::string filepath = sparams.slot_save_path + filename;
server_task task; server_task task;
task.type = SERVER_TASK_TYPE_SLOT_RESTORE; task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
task.data = { task.data = {
{ "id_slot", id_slot }, { "id_slot", id_slot },
{ "filename", filename }, { "filename", filename },
{ "filepath", filepath }
}; };
const int id_task = ctx_server.queue_tasks.post(task); const int id_task = ctx_server.queue_tasks.post(task);
@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize); svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize); svr->Post("/detokenize", handle_detokenize);
svr->Post("/slot/save", handle_slot_save); if (!sparams.slot_save_path.empty()) {
svr->Post("/slot/restore", handle_slot_restore); // only enable slot endpoints if slot_save_path is set
svr->Post("/slot/erase", handle_slot_erase); svr->Post("/slot/save", handle_slot_save);
svr->Post("/slot/restore", handle_slot_restore);
svr->Post("/slot/erase", handle_slot_erase);
}
// //
// Start the server // Start the server