add --slot-save-path arg to enable save restore and restrict save location
This commit is contained in:
parent
02a184065a
commit
b8e8facb0e
1 changed files with 37 additions and 8 deletions
|
@ -131,6 +131,7 @@ struct server_params {
|
|||
|
||||
bool slots_endpoint = true;
|
||||
bool metrics_endpoint = false;
|
||||
std::string slot_save_path;
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
|
@ -1628,6 +1629,7 @@ struct server_context {
|
|||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data["filename"];
|
||||
std::string filepath = task.data["filepath"];
|
||||
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
|
||||
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
||||
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
|
||||
|
@ -1645,7 +1647,7 @@ struct server_context {
|
|||
}
|
||||
GGML_ASSERT(nwrite <= state_data.size());
|
||||
|
||||
std::ofstream outfile(filename, std::ios::binary);
|
||||
std::ofstream outfile(filepath, std::ios::binary);
|
||||
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
|
||||
outfile.close();
|
||||
|
||||
|
@ -1678,8 +1680,9 @@ struct server_context {
|
|||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params?
|
||||
std::ifstream infile(filename, std::ios::binary);
|
||||
std::string filename = task.data["filename"];
|
||||
std::string filepath = task.data["filepath"];
|
||||
std::ifstream infile(filepath, std::ios::binary);
|
||||
if (!infile.is_open()) {
|
||||
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
|
@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
|||
printf(" --log-disable disables logging to a file.\n");
|
||||
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
||||
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
|
||||
printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
|
||||
printf("\n");
|
||||
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||
|
@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
|||
sparams.slots_endpoint = false;
|
||||
} else if (arg == "--metrics") {
|
||||
sparams.metrics_endpoint = true;
|
||||
} else if (arg == "--slot-save-path") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.slot_save_path = argv[i];
|
||||
// if doesn't end with DIRECTORY_SEPARATOR, add it
|
||||
if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
|
||||
sparams.slot_save_path += DIRECTORY_SEPARATOR;
|
||||
}
|
||||
} else if (arg == "--chat-template") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) {
|
|||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
json request_data = json::parse(req.body);
|
||||
int id_slot = request_data["id_slot"];
|
||||
std::string filename = request_data["filename"];
|
||||
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
|
||||
res_error(res, "Invalid filename");
|
||||
return;
|
||||
}
|
||||
std::string filepath = sparams.slot_save_path + filename;
|
||||
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
|
||||
task.data = {
|
||||
{ "id_slot", id_slot },
|
||||
{ "filename", filename },
|
||||
{ "filepath", filepath }
|
||||
};
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
|
@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
json request_data = json::parse(req.body);
|
||||
int id_slot = request_data["id_slot"];
|
||||
std::string filename = request_data["filename"];
|
||||
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
|
||||
res_error(res, "Invalid filename");
|
||||
return;
|
||||
}
|
||||
std::string filepath = sparams.slot_save_path + filename;
|
||||
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
|
||||
task.data = {
|
||||
{ "id_slot", id_slot },
|
||||
{ "filename", filename },
|
||||
{ "filepath", filepath }
|
||||
};
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
|
@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) {
|
|||
svr->Post("/v1/embeddings", handle_embeddings);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/slot/save", handle_slot_save);
|
||||
svr->Post("/slot/restore", handle_slot_restore);
|
||||
svr->Post("/slot/erase", handle_slot_erase);
|
||||
if (!sparams.slot_save_path.empty()) {
|
||||
// only enable slot endpoints if slot_save_path is set
|
||||
svr->Post("/slot/save", handle_slot_save);
|
||||
svr->Post("/slot/restore", handle_slot_restore);
|
||||
svr->Post("/slot/erase", handle_slot_erase);
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue