diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 374ed47ad..b812e4c96 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -31,6 +31,10 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static const std::string CMD_READFILE = "/readfile"; +static const std::string CMD_SAVE_SESS = "/savesess"; +static const std::string CMD_LOAD_SESS = "/loadsess"; + static llama_context ** g_ctx; static llama_model ** g_model; static common_sampler ** g_smpl; @@ -851,6 +855,43 @@ int main(int argc, char ** argv) { LOG_DBG("buffer: '%s'\n", buffer.c_str()); + // check for special commands + if (buffer.rfind(CMD_READFILE, 0) == 0) { + const std::string filename = string_strip(buffer.substr(CMD_READFILE.length())); + LOG_DBG("reading file: '%s'\n", filename.c_str()); + std::ifstream text_file(filename); + if (!text_file) { + LOG("failed to open file '%s'\n", filename.c_str()); + continue; + } + std::stringstream tmp; + tmp << text_file.rdbuf(); + buffer = tmp.str(); + LOG("%s\n", buffer.c_str()); + } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) { + const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length())); + LOG("save session file: '%s'\n", filename.c_str()); + size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past); + if (res == 0) { + LOG("failed to save session file '%s'\n", filename.c_str()); + } + continue; + } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) { + const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length())); + LOG("load session file: '%s'\n", filename.c_str()); + std::vector sess_tokens; + sess_tokens.resize(n_ctx); + size_t n_loaded_tokens; + size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens); + if (res == 0) { + LOG("failed to load session file '%s'\n", filename.c_str()); + } else { + n_past = n_loaded_tokens; + LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str()); + } + continue; + } + const size_t original_size = embd_inp.size(); if (params.escape) {