diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 718bb6029..fcf1af5b0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) @@ -167,7 +168,7 @@ int main(int argc, char ** argv) { // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { params.interactive_start = true; - params.antiprompt.push_back("### Instruction:\n\n"); + params.antiprompt.push_back("###"); } // enable interactive mode if reverse prompt or interactive start is specified @@ -438,8 +439,53 @@ int main(int argc, char ** argv) { } n_remain -= line_inp.size(); - } + // check buffer's first 3 chars equal '???' to enter command mode. + if (strncmp(buffer.c_str(), "???", 3) == 0 && buffer.length() > 3) { + std::istringstream command(buffer); + int j = 0; std::string test, arg, cmd; + while (command>>test) { + j++; + if ( j == 2 ) { + arg = test; + } + if ( j == 3 ) { + cmd = test; + } + } + if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(cmd); + } else if (arg == "--top_k") { + params.top_k = std::stoi(cmd); + } else if (arg == "-c" || arg == "--ctx_size") { + params.n_ctx = std::stoi(cmd); + } else if (arg == "--top_p") { + params.top_p = std::stof(cmd); + } else if (arg == "--temp") { + params.temp = std::stof(cmd); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(cmd); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(cmd); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch = std::stoi(cmd); + params.n_batch = std::min(512, params.n_batch); + } else if (arg == "-r" || arg == "--reverse-prompt") { + params.antiprompt.push_back(cmd); + } else if (arg == "--keep") { + params.n_keep = std::stoi(cmd); + } else if (arg == "stats") { + llama_print_timings(ctx); + } else { + printf("invalid argument parsed: %s\n", arg.c_str()); + } + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + } + } + input_noecho = true; // do not echo this again }