From b8e8facb0ee25cafc36cca5edf27c74f21f6edc8 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Thu, 28 Mar 2024 00:05:56 +0800 Subject: [PATCH] add --slot-save-path arg to enable save restore and restrict save location --- examples/server/server.cpp | 45 +++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 227bb3c6b..a86a20ae6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -131,6 +131,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; + std::string slot_save_path; }; struct server_slot { @@ -1628,6 +1629,7 @@ struct server_context { const int64_t t_start = ggml_time_us(); std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; size_t state_size = llama_get_seq_size(ctx, slot->id + 1); std::vector state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); @@ -1645,7 +1647,7 @@ struct server_context { } GGML_ASSERT(nwrite <= state_data.size()); - std::ofstream outfile(filename, std::ios::binary); + std::ofstream outfile(filepath, std::ios::binary); outfile.write(reinterpret_cast(state_data.data()), nwrite); outfile.close(); @@ -1678,8 +1680,9 @@ struct server_context { const int64_t t_start = ggml_time_us(); - std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params? - std::ifstream infile(filename, std::ios::binary); + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + std::ifstream infile(filepath, std::ios::binary); if (!infile.is_open()) { send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST); break; @@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); + printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, sparams.slots_endpoint = false; } else if (arg == "--metrics") { sparams.metrics_endpoint = true; + } else if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + sparams.slot_save_path += DIRECTORY_SEPARATOR; + } } else if (arg == "--chat-template") { if (++i >= argc) { invalid_param = true; @@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json request_data = json::parse(req.body); int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; + if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + res_error(res, "Invalid filename"); + return; + } + std::string filepath = sparams.slot_save_path + filename; server_task task; task.type = SERVER_TASK_TYPE_SLOT_SAVE; task.data = { { "id_slot", id_slot }, { "filename", filename }, + { "filepath", filepath } }; const int id_task = ctx_server.queue_tasks.post(task); @@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) { } }; - const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json request_data = json::parse(req.body); int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; + if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + res_error(res, "Invalid filename"); + return; + } + std::string filepath = sparams.slot_save_path + filename; server_task task; task.type = SERVER_TASK_TYPE_SLOT_RESTORE; task.data = { { "id_slot", id_slot }, { "filename", filename }, + { "filepath", filepath } }; const int id_task = ctx_server.queue_tasks.post(task); @@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); - svr->Post("/slot/save", handle_slot_save); - svr->Post("/slot/restore", handle_slot_restore); - svr->Post("/slot/erase", handle_slot_erase); + if (!sparams.slot_save_path.empty()) { + // only enable slot endpoints if slot_save_path is set + svr->Post("/slot/save", handle_slot_save); + svr->Post("/slot/restore", handle_slot_restore); + svr->Post("/slot/erase", handle_slot_erase); + } // // Start the server