diff --git a/examples/common.cpp b/examples/common.cpp index f149206cd..212974481 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -245,6 +246,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.antiprompt.push_back("### Instruction:\n\n"); } else if (arg == "--color") { params.use_color = true; + } else if (arg == "--multiline") { + params.multiline_mode = true; } else if (arg == "--mlock") { params.use_mlock = true; } else if (arg == "--no-mmap") { @@ -323,6 +326,7 @@ void gpt_print_usage(char * argv_0, const gpt_params & params) { fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); + fprintf(stderr, " --multiline multiline mode (use Ctrl+D on Linux/Mac and Ctrl+Z on Windpws to send input)\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); @@ -441,3 +445,51 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) { str = strTo; } #endif + +bool get_input_text(std::string & input_text, bool escape_newline_mode) { + bool another_line = true; + do { + std::string line; +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // input stream is bad or EOF received + if (std::wcin.bad()) { + fprintf(stderr, "%s: error: input stream bad\n", __func__); + return 1; + } + } + if (std::wcin.eof()) { + another_line = false; + std::wcin.clear(); + std::wcin.seekg(0, std::ios::beg); + } + win32_utf8_encode(wline, line); +#else + if (!std::getline(std::cin, line)) { + // input stream is bad or EOF received + if (std::wcin.bad()) { + fprintf(stderr, "%s: error: input stream bad\n", __func__); + return 1; + } + } + if (std::ccin.eof()) { + another_line = false; + std::cin.clear(); + std::cin.seekg(0, std::ios::beg); + } +#endif + if (escape_newline_mode) { + if (line.empty() || line.back() != '\\') { + another_line = false; + } else { + line.pop_back(); // Remove the continue character + } + } + input_text += line; + if (another_line) { + input_text += '\n'; // Append the line to the result + } + } while (another_line); + return true; +} diff --git a/examples/common.h b/examples/common.h index 3a7f103e5..2c4632b41 100644 --- a/examples/common.h +++ b/examples/common.h @@ -56,6 +56,7 @@ struct gpt_params { bool verbose_prompt = false; // print prompt tokens before generation bool clean_interface = false; // hides input prefix & suffix and displays '>' + bool multiline_mode = false; // enables multi-line mode, to send input press CTRL+D on Linux/Max, CTRL+Z on Windows }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); @@ -100,3 +101,5 @@ void set_console_color(console_state & con_st, console_color_t color); void win32_console_init(bool enable_color); void win32_utf8_encode(const std::wstring & wstr, std::string & str); #endif + +bool get_input_text(std::string & input_text, bool escape_newline_mode); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 744c96839..7ea9c3654 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -25,7 +25,8 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. + fflush(stdout); + fflush(stderr); if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; @@ -228,8 +229,18 @@ int main(int argc, char ** argv) { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + ); + if (params.multiline_mode) { +#if defined (_WIN32) + fprintf(stderr, " - Press Ctrl+Z and Return (EOF) to return control to LLaMa.\n\n"); +#else + fprintf(stderr, " - Press Ctrl+D (EOF) to return control to LLaMa.\n\n"); +#endif + } + else { + fprintf(stderr, " - Press Return to return control to LLaMa.\n" + " - If you want to submit another line, end your input in '\\'.\n\n"); + } is_interacting = params.interactive_start; } @@ -424,33 +435,13 @@ int main(int argc, char ** argv) { printf("\n> "); } - std::string line; - bool another_line = true; - do { - // TODO: try to revert going to new line after enter (to enable in-line text writing) -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (line.empty() || line.back() != '\\') { - another_line = false; - } else { - line.pop_back(); // Remove the continue character - } - buffer += line; - if (another_line || !antiprompt.is_stop_prompt) { - buffer += '\n'; // Append the line to the result - } - } while (another_line); + if (!get_input_text(buffer, !params.multiline_mode)) { + // input stream is bad + return 1; + } + if (!antiprompt.is_stop_prompt) { + buffer += "\n"; + } // done taking input, reset color set_console_color(con_st, CONSOLE_COLOR_DEFAULT);