add some other commands

Xuan Son Nguyen 2024-11-03 22:16:14 +01:00
parent d7a4f3e497
commit 1716e6b25a
3 changed files with 113 additions and 21 deletions

common/arg.cpp

@@ -1939,6 +1939,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.simple_io = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
add_opt(common_arg(
{"-nsc", "--no-special-command"},
string_format("disable special commands in conversation mode (default: %s)", params.special_cmds ? "enabled" : "disabled"),
[](common_params & params) {
params.special_cmds = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"-ld", "--logdir"}, "LOGDIR",
"path under which to save YAML logs (no logging if unset)",

common/common.h

@@ -251,6 +251,7 @@ struct common_params {
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
bool special_cmds = true; // enable special commands in main example
bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool multiline_input = false; // reverse the usage of `\`

examples/main/main.cpp

@@ -31,10 +31,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static const std::string CMD_READFILE = "/readfile";
static const std::string CMD_SAVE_SESS = "/savesess";
static const std::string CMD_LOAD_SESS = "/loadsess";
static llama_context ** g_ctx;
static llama_model ** g_model;
static common_sampler ** g_smpl;
@@ -45,6 +41,13 @@ static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
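// help text for the conversation-mode commands below; printed by print_usage() and again
// whenever an unrecognized "/..." command is entered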
static const char * help_special_cmds = "special commands in conversation mode:\n"
" /readfile FILE read prompt from file\n"
" /savesess FILE save session to file\n"
" /loadsess FILE load session from file\n"
" /regen regenerate the last response\n"
" /dump FILE dump chat content to a file\n";
static void print_usage(int argc, char ** argv) {
(void) argc;
@@ -52,6 +55,8 @@ static void print_usage(int argc, char ** argv) {
LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
LOG("\n chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
LOG("\n");
LOG("%s", help_special_cmds);
LOG("\n");
}
static bool file_exists(const std::string & path) {
@@ -109,6 +114,21 @@ static void write_logfile(
fclose(logfile);
}
static std::vector<std::string> try_parse_command(std::string text) {
if (text.empty() || text[0] != '/') {
return {};
}
std::vector<std::string> elem = string_split<std::string>(text, ' ');
std::vector<std::string> res;
// filter empty strings
for (const auto & e : elem) {
if (!e.empty()) {
res.push_back(string_strip(e));
}
}
return res;
}
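// illustrative example (not part of the original patch): try_parse_command("/readfile  notes.txt")
// yields {"/readfile", "notes.txt"}; extra whitespace is dropped, and any input that does not
// start with '/' returns an empty vector, so regular chat text is never treated as a command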
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
@@ -131,7 +151,11 @@ static void sigint_handler(int signo) {
}
#endif
// return the formatted turn to be decoded
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
if (content.empty()) {
return "";
}
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
@@ -193,6 +217,7 @@ int main(int argc, char ** argv) {
llama_context * ctx = nullptr;
common_sampler * smpl = nullptr;
std::vector<int> pos_history; // history of positions of chat messages
std::vector<common_chat_msg> chat_msgs;
g_model = &model;
@@ -519,6 +544,7 @@ int main(int argc, char ** argv) {
display = params.display_prompt;
std::vector<llama_token> embd;
llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
// tokenized antiprompts
std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -546,6 +572,8 @@ int main(int argc, char ** argv) {
embd_inp.push_back(decoder_start_token_id);
}
std::stringstream pending_input; // used by "/readfile" command
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
@@ -652,7 +680,19 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
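// build the batch explicitly: each token goes to sequence 0 at position n_past + idx,
// and logits are requested only for the last token of embd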
common_batch_clear(batch);
for (int j = 0; j < n_eval; j++) {
int idx = i + j;
common_batch_add(
batch,
embd[idx],
n_past + idx,
{0},
idx == (int) embd.size() - 1
);
}
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -856,40 +896,81 @@ int main(int argc, char ** argv) {
LOG_DBG("buffer: '%s'\n", buffer.c_str());
// check for special commands
if (buffer.rfind(CMD_READFILE, 0) == 0) {
const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
const std::vector<std::string> cmd = params.special_cmds
? try_parse_command(buffer)
: std::vector<std::string>();
if (cmd.size() == 2 && cmd[0] == "/readfile") {
const std::string filename = cmd[1];
LOG_DBG("reading file: '%s'\n", filename.c_str());
std::ifstream text_file(filename);
if (!text_file) {
LOG("failed to open file '%s'\n", filename.c_str());
continue;
}
std::stringstream tmp;
tmp << text_file.rdbuf();
buffer = tmp.str();
LOG("%s\n", buffer.c_str());
} else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
pending_input << text_file.rdbuf() << "\n\n";
LOG("read %zu characters from file\n", (size_t) text_file.tellg());
continue;
} else if (cmd.size() == 2 && cmd[0] == "/savesess") {
const std::string filename = cmd[1];
LOG("save session file: '%s'\n", filename.c_str());
size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
if (res == 0) {
LOG("failed to save session file '%s'\n", filename.c_str());
}
continue;
} else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
} else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
const std::string filename = cmd[1];
LOG("load session file: '%s'\n", filename.c_str());
std::vector<llama_token> sess_tokens;
sess_tokens.resize(n_ctx);
size_t n_loaded_tokens;
size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
session_tokens.resize(n_ctx);
size_t n_token_count_out;
size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
if (res == 0) {
LOG("failed to load session file '%s'\n", filename.c_str());
} else {
n_past = n_loaded_tokens;
LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
session_tokens.resize(n_token_count_out);
embd_inp = session_tokens;
n_past = n_token_count_out;
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
}
continue;
} else if (cmd.size() == 1 && cmd[0] == "/regen") {
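// roll back to the position recorded before the last assistant reply: remove its tokens
// from the KV cache and return them to the n_remain budget so the reply can be regenerated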
if (pos_history.empty()) {
LOG("no previous assistant message to regenerate\n");
continue;
}
int last_n_past = pos_history.back();
int n_tokens_removed = n_past - last_n_past;
llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
n_remain += n_tokens_removed;
is_interacting = false;
// we intentionally do not reset the sampling, so the new message will be more diverse
continue;
} else if (cmd.size() == 2 && cmd[0] == "/dump") {
const std::string filename = cmd[1];
std::ofstream dump_file(filename);
if (!dump_file) {
LOG("failed to create file '%s'\n", filename.c_str());
continue;
}
for (const auto & msg : chat_msgs) {
dump_file << msg.role << ":\n" << msg.content << "\n---\n";
}
dump_file.close();
LOG("dumped chat messages to file '%s'\n", filename.c_str());
continue;
} else if (!cmd.empty()) {
LOG("unknown command: %s\n", buffer.c_str());
LOG("%s", help_special_cmds);
continue;
}
if (pending_input.tellp() > 0) {
// concatenate read file and the prompt
pending_input << buffer;
buffer = pending_input.str();
pending_input.str(""); // note: stringstream::clear() only resets error flags, so reset the buffered content explicitly
}
const size_t original_size = embd_inp.size();
@@ -926,6 +1007,8 @@ int main(int argc, char ** argv) {
output_ss << common_token_to_piece(ctx, token);
}
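// record the position reached after this user turn so that /regen can roll back to it later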
pos_history.push_back(n_past + embd_inp.size() - original_size);
// reset assistant message
assistant_ss.str("");
@@ -971,6 +1054,7 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);