add --slot-save-path arg to enable save restore and restrict save location
This commit is contained in:
parent
02a184065a
commit
b8e8facb0e
1 changed files with 37 additions and 8 deletions
|
@ -131,6 +131,7 @@ struct server_params {
|
||||||
|
|
||||||
bool slots_endpoint = true;
|
bool slots_endpoint = true;
|
||||||
bool metrics_endpoint = false;
|
bool metrics_endpoint = false;
|
||||||
|
std::string slot_save_path;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
|
@ -1628,6 +1629,7 @@ struct server_context {
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
std::string filename = task.data["filename"];
|
std::string filename = task.data["filename"];
|
||||||
|
std::string filepath = task.data["filepath"];
|
||||||
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
|
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
|
||||||
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
||||||
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
|
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
|
||||||
|
@ -1645,7 +1647,7 @@ struct server_context {
|
||||||
}
|
}
|
||||||
GGML_ASSERT(nwrite <= state_data.size());
|
GGML_ASSERT(nwrite <= state_data.size());
|
||||||
|
|
||||||
std::ofstream outfile(filename, std::ios::binary);
|
std::ofstream outfile(filepath, std::ios::binary);
|
||||||
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
|
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
|
||||||
outfile.close();
|
outfile.close();
|
||||||
|
|
||||||
|
@ -1678,8 +1680,9 @@ struct server_context {
|
||||||
|
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params?
|
std::string filename = task.data["filename"];
|
||||||
std::ifstream infile(filename, std::ios::binary);
|
std::string filepath = task.data["filepath"];
|
||||||
|
std::ifstream infile(filepath, std::ios::binary);
|
||||||
if (!infile.is_open()) {
|
if (!infile.is_open()) {
|
||||||
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
|
||||||
break;
|
break;
|
||||||
|
@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
||||||
printf(" --log-disable disables logging to a file.\n");
|
printf(" --log-disable disables logging to a file.\n");
|
||||||
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
||||||
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
|
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
|
||||||
|
printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
||||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||||
|
@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
||||||
sparams.slots_endpoint = false;
|
sparams.slots_endpoint = false;
|
||||||
} else if (arg == "--metrics") {
|
} else if (arg == "--metrics") {
|
||||||
sparams.metrics_endpoint = true;
|
sparams.metrics_endpoint = true;
|
||||||
|
} else if (arg == "--slot-save-path") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.slot_save_path = argv[i];
|
||||||
|
// if doesn't end with DIRECTORY_SEPARATOR, add it
|
||||||
|
if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
|
||||||
|
sparams.slot_save_path += DIRECTORY_SEPARATOR;
|
||||||
|
}
|
||||||
} else if (arg == "--chat-template") {
|
} else if (arg == "--chat-template") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) {
|
||||||
res.status = 200; // HTTP OK
|
res.status = 200; // HTTP OK
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
json request_data = json::parse(req.body);
|
json request_data = json::parse(req.body);
|
||||||
int id_slot = request_data["id_slot"];
|
int id_slot = request_data["id_slot"];
|
||||||
std::string filename = request_data["filename"];
|
std::string filename = request_data["filename"];
|
||||||
|
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
|
||||||
|
res_error(res, "Invalid filename");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string filepath = sparams.slot_save_path + filename;
|
||||||
|
|
||||||
server_task task;
|
server_task task;
|
||||||
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
|
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
|
||||||
task.data = {
|
task.data = {
|
||||||
{ "id_slot", id_slot },
|
{ "id_slot", id_slot },
|
||||||
{ "filename", filename },
|
{ "filename", filename },
|
||||||
|
{ "filepath", filepath }
|
||||||
};
|
};
|
||||||
|
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
|
@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
json request_data = json::parse(req.body);
|
json request_data = json::parse(req.body);
|
||||||
int id_slot = request_data["id_slot"];
|
int id_slot = request_data["id_slot"];
|
||||||
std::string filename = request_data["filename"];
|
std::string filename = request_data["filename"];
|
||||||
|
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
|
||||||
|
res_error(res, "Invalid filename");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string filepath = sparams.slot_save_path + filename;
|
||||||
|
|
||||||
server_task task;
|
server_task task;
|
||||||
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
|
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
|
||||||
task.data = {
|
task.data = {
|
||||||
{ "id_slot", id_slot },
|
{ "id_slot", id_slot },
|
||||||
{ "filename", filename },
|
{ "filename", filename },
|
||||||
|
{ "filepath", filepath }
|
||||||
};
|
};
|
||||||
|
|
||||||
const int id_task = ctx_server.queue_tasks.post(task);
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
|
@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) {
|
||||||
svr->Post("/v1/embeddings", handle_embeddings);
|
svr->Post("/v1/embeddings", handle_embeddings);
|
||||||
svr->Post("/tokenize", handle_tokenize);
|
svr->Post("/tokenize", handle_tokenize);
|
||||||
svr->Post("/detokenize", handle_detokenize);
|
svr->Post("/detokenize", handle_detokenize);
|
||||||
svr->Post("/slot/save", handle_slot_save);
|
if (!sparams.slot_save_path.empty()) {
|
||||||
svr->Post("/slot/restore", handle_slot_restore);
|
// only enable slot endpoints if slot_save_path is set
|
||||||
svr->Post("/slot/erase", handle_slot_erase);
|
svr->Post("/slot/save", handle_slot_save);
|
||||||
|
svr->Post("/slot/restore", handle_slot_restore);
|
||||||
|
svr->Post("/slot/erase", handle_slot_erase);
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Start the server
|
// Start the server
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue