add some other commands
parent d7a4f3e497 · commit 1716e6b25a
3 changed files with 113 additions and 21 deletions
@@ -1939,6 +1939,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.simple_io = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    add_opt(common_arg(
+        {"-nsc", "--no-special-command"},
+        string_format("disable special commands in conversation mode (default: %s)", params.special_cmds ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.special_cmds = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-ld", "--logdir"}, "LOGDIR",
         "path under which to save YAML logs (no logging if unset)",

@@ -251,6 +251,7 @@ struct common_params {
     bool conversation     = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro  = false; // open the prompt cache read-only and do not update it
+    bool special_cmds     = true;  // enable special commands in main example

     bool escape           = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input  = false; // reverse the usage of `\`

@@ -31,10 +31,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif

-static const std::string CMD_READFILE  = "/readfile";
-static const std::string CMD_SAVE_SESS = "/savesess";
-static const std::string CMD_LOAD_SESS = "/loadsess";
-
 static llama_context  ** g_ctx;
 static llama_model    ** g_model;
 static common_sampler ** g_smpl;

@@ -45,6 +41,13 @@ static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting  = false;
 static bool need_insert_eot = false;

+static const char * help_special_cmds = "special commands in conversation mode:\n"
+    "  /readfile FILE   read prompt from file\n"
+    "  /savesess FILE   save session to file\n"
+    "  /loadsess FILE   load session from file\n"
+    "  /regen           regenerate the last response\n"
+    "  /dump FILE       dump chat content to a file\n";
+
 static void print_usage(int argc, char ** argv) {
     (void) argc;

@@ -52,6 +55,8 @@ static void print_usage(int argc, char ** argv) {
     LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
     LOG("\n  chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
     LOG("\n");
+    LOG("%s", help_special_cmds);
+    LOG("\n");
 }

 static bool file_exists(const std::string & path) {

@@ -109,6 +114,21 @@ static void write_logfile(
     fclose(logfile);
 }

+static std::vector<std::string> try_parse_command(std::string text) {
+    if (text.empty() || text[0] != '/') {
+        return {};
+    }
+    std::vector<std::string> elem = string_split<std::string>(text, ' ');
+    std::vector<std::string> res;
+    // filter empty strings
+    for (const auto & e : elem) {
+        if (!e.empty()) {
+            res.push_back(string_strip(e));
+        }
+    }
+    return res;
+}
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {

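try_parse_command treats a line as a command only when it starts with "/"; anything else falls through and is handled as a normal prompt. A standalone sketch of the same splitting behaviour (illustrative only, not part of this commit; std::istringstream stands in for the string_split/string_strip helpers from common):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // approximation of try_parse_command: split a "/..." line on whitespace
    static std::vector<std::string> parse_cmd_sketch(const std::string & text) {
        if (text.empty() || text[0] != '/') {
            return {}; // not a command - treated as a regular prompt
        }
        std::vector<std::string> res;
        std::istringstream ss(text);
        for (std::string tok; ss >> tok; ) { // operator>> skips runs of spaces
            res.push_back(tok);
        }
        return res;
    }

    int main() {
        for (const auto & p : parse_cmd_sketch("/readfile   notes.txt")) {
            std::cout << "[" << p << "] ";                                  // [/readfile] [notes.txt]
        }
        std::cout << "\n" << parse_cmd_sketch("plain text").size() << "\n"; // 0
    }
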
@@ -131,7 +151,11 @@ static void sigint_handler(int signo) {
 }
 #endif

+// return the formatted turn to be decoded
 static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
+    if (content.empty()) {
+        return "";
+    }
     common_chat_msg new_msg{role, content};
     auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});

@@ -193,6 +217,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = nullptr;
     common_sampler * smpl = nullptr;

+    std::vector<int> pos_history; // history of positions of chat messages
     std::vector<common_chat_msg> chat_msgs;

     g_model = &model;

@@ -519,6 +544,7 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;

     std::vector<llama_token> embd;
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);

     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;

@@ -546,6 +572,8 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }

+    std::stringstream pending_input; // used by "/readfile" command
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {

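pending_input buffers whatever a "/readfile" turn pulls in, and the next regular prompt is appended to it before the combined text is submitted as a single turn (see the end of the large hunk further down). A self-contained sketch of that buffering pattern with made-up content; note that std::stringstream::clear() only resets the stream's error flags, while str("") is what actually empties the buffer:

    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
        std::stringstream pending_input;

        // what a "/readfile notes.txt" turn would buffer (file name and content are made up)
        pending_input << "contents of notes.txt" << "\n\n";

        std::string buffer = "summarize the file above"; // the next regular prompt
        if (pending_input.tellp() > 0) {
            pending_input << buffer;       // concatenate read file and the prompt
            buffer = pending_input.str();
            pending_input.str("");         // empty the buffer
            pending_input.clear();         // reset stream state flags
        }
        std::cout << buffer << "\n";
    }
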
@@ -652,7 +680,19 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+            common_batch_clear(batch);
+            for (int j = 0; j < n_eval; j++) {
+                int idx = i + j;
+                common_batch_add(
+                    batch,
+                    embd[idx],
+                    n_past + idx,
+                    {0},
+                    idx == (int) embd.size() - 1
+                );
+            }
+
+            if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }

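Decoding now goes through the explicitly built batch instead of llama_batch_get_one, so every token carries its absolute position and sequence id, and only the last token of the chunk requests logits. A compilable sketch of what the common_batch_clear/common_batch_add helpers are assumed to boil down to for a single sequence (field names follow llama_batch in llama.h; the token values are made up):

    #include "llama.h"

    #include <vector>

    int main() {
        // a stand-in token chunk; the real code gets these from the tokenizer
        std::vector<llama_token> embd = {1, 2, 3, 4};
        const int n_past = 0;
        const int n_eval = (int) embd.size();

        llama_batch batch = llama_batch_init(n_eval, 0, 1);

        batch.n_tokens = 0;                          // what common_batch_clear is assumed to do
        for (int j = 0; j < n_eval; j++) {
            batch.token   [j]    = embd[j];
            batch.pos     [j]    = n_past + j;       // absolute position in the context
            batch.n_seq_id[j]    = 1;
            batch.seq_id  [j][0] = 0;                // the single conversation sequence
            batch.logits  [j]    = j == n_eval - 1;  // request logits only for the last token
            batch.n_tokens++;
        }

        // llama_decode(ctx, batch) would evaluate the chunk; omitted here, it needs a loaded model
        llama_batch_free(batch);
        return 0;
    }
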
@@ -856,40 +896,81 @@ int main(int argc, char ** argv) {
                 LOG_DBG("buffer: '%s'\n", buffer.c_str());

                 // check for special commands
-                if (buffer.rfind(CMD_READFILE, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
+                const std::vector<std::string> cmd = params.special_cmds
+                    ? try_parse_command(buffer)
+                    : std::vector<std::string>();
+
+                if (cmd.size() == 2 && cmd[0] == "/readfile") {
+                    const std::string filename = cmd[1];
                     LOG_DBG("reading file: '%s'\n", filename.c_str());
                     std::ifstream text_file(filename);
                     if (!text_file) {
                         LOG("failed to open file '%s'\n", filename.c_str());
                         continue;
                     }
-                    std::stringstream tmp;
-                    tmp << text_file.rdbuf();
-                    buffer = tmp.str();
-                    LOG("%s\n", buffer.c_str());
-                } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
+                    pending_input << text_file.rdbuf() << "\n\n";
+                    LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/savesess") {
+                    const std::string filename = cmd[1];
                     LOG("save session file: '%s'\n", filename.c_str());
                     size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
                     if (res == 0) {
                         LOG("failed to save session file '%s'\n", filename.c_str());
                     }
                     continue;
-                } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
+                } else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
+                    const std::string filename = cmd[1];
                     LOG("load session file: '%s'\n", filename.c_str());
-                    std::vector<llama_token> sess_tokens;
-                    sess_tokens.resize(n_ctx);
-                    size_t n_loaded_tokens;
-                    size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
+                    session_tokens.resize(n_ctx);
+                    size_t n_token_count_out;
+                    size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
                     if (res == 0) {
                         LOG("failed to load session file '%s'\n", filename.c_str());
                     } else {
-                        n_past = n_loaded_tokens;
-                        LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
+                        session_tokens.resize(n_token_count_out);
+                        embd_inp = session_tokens;
+                        n_past = n_token_count_out;
+                        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+                        LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
                     }
                     continue;
+                } else if (cmd.size() == 1 && cmd[0] == "/regen") {
+                    if (pos_history.empty()) {
+                        LOG("no previous assistant message to regenerate\n");
+                        continue;
+                    }
+                    int last_n_past = pos_history.back();
+                    int n_tokens_removed = n_past - last_n_past;
+                    llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
+                    n_remain += n_tokens_removed;
+                    is_interacting = false;
+                    // we intentionally do not reset the sampling, so new message will be more diverse
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/dump") {
+                    const std::string filename = cmd[1];
+                    std::ofstream dump_file(filename);
+                    if (!dump_file) {
+                        LOG("failed to create file '%s'\n", filename.c_str());
+                        continue;
+                    }
+                    for (const auto & msg : chat_msgs) {
+                        dump_file << msg.role << ":\n" << msg.content << "\n---\n";
+                    }
+                    dump_file.close();
+                    LOG("dumped chat messages to file '%s'\n", filename.c_str());
+                    continue;
+                } else if (!cmd.empty()) {
+                    LOG("unknown command: %s\n", buffer.c_str());
+                    LOG("%s", help_special_cmds);
+                    continue;
                 }

+                if (pending_input.tellp() > 0) {
+                    // concatenate read file and the prompt
+                    pending_input << buffer;
+                    buffer = pending_input.str();
+                    pending_input.clear();
+                }
+
                 const size_t original_size = embd_inp.size();

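The "/dump" handler writes each chat turn as the role, the content, and a "---" separator line. A throwaway sketch of the resulting file layout, with a stand-in struct instead of common_chat_msg (only its role/content fields are used here) and an example file name:

    #include <fstream>
    #include <string>
    #include <vector>

    struct msg { std::string role, content; };    // stand-in for common_chat_msg

    int main() {
        const std::vector<msg> chat_msgs = {
            {"user",      "hello"},
            {"assistant", "hi, how can I help?"},
        };

        std::ofstream dump_file("chat_dump.txt"); // file name is just an example
        for (const auto & m : chat_msgs) {
            dump_file << m.role << ":\n" << m.content << "\n---\n";
        }
        // chat_dump.txt now reads:
        //   user:
        //   hello
        //   ---
        //   assistant:
        //   hi, how can I help?
        //   ---
    }
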
@@ -926,6 +1007,8 @@ int main(int argc, char ** argv) {
                    output_ss << common_token_to_piece(ctx, token);
                }

+                pos_history.push_back(n_past + embd_inp.size() - original_size);
+
                // reset assistant message
                assistant_ss.str("");

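pos_history records, for each user turn, the context position at which the following assistant reply will start (the current n_past plus the tokens just appended to embd_inp). "/regen" takes the last entry, works out how many positions to evict from the KV cache, and refunds that amount to the n_remain budget. A small arithmetic sketch with made-up numbers and no llama calls:

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<int> pos_history;

        int n_past = 30;                 // context length before the user's new input
        const int original_size = 100;   // embd_inp.size() before the input was appended
        const int new_size      = 112;   // embd_inp.size() after appending 12 input tokens

        // recorded after tokenizing the user turn: where the assistant reply will start
        pos_history.push_back(n_past + new_size - original_size);    // 42

        n_past = 42 + 25;                // input decoded, then 25 reply tokens generated

        // "/regen": size of the reply that has to be evicted and re-generated
        const int last_n_past      = pos_history.back();             // 42
        const int n_tokens_removed = n_past - last_n_past;           // 25
        assert(n_tokens_removed == 25);
        // llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1) then drops positions >= 42,
        // and n_remain += n_tokens_removed refunds the generation budget
    }
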
@@ -971,6 +1054,7 @@ int main(int argc, char ** argv) {

     common_sampler_free(smpl);

+    llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
