diff --git a/examples/common.cpp b/examples/common.cpp index 97eded6ec..cd37f96f4 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int const wchar_t * lpWideCharStr, int cchWideChar, char * lpMultiByteStr, int cbMultiByte, const char * lpDefaultChar, bool * lpUsedDefaultChar); +#define ENABLE_LINE_INPUT 0x0002 +#define ENABLE_ECHO_INPUT 0x0004 #define CP_UTF8 65001 +#define CONSOLE_CHAR_TYPE wchar_t +#define CONSOLE_GET_CHAR() getwchar() +#define CONSOLE_EOF WEOF +#else +#include +#define CONSOLE_CHAR_TYPE char +#define CONSOLE_GET_CHAR() getchar() +#define CONSOLE_EOF EOF #endif int32_t get_num_physical_cores() { @@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.embedding = true; } else if (arg == "--interactive-first") { params.interactive_first = true; + } else if (arg == "--author-mode") { + params.author_mode = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; } else if (arg == "--color") { @@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { } /* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { +void console_set_color(console_state & con_st, console_color_t color) { if (con_st.use_color && con_st.color != color) { switch(color) { case CONSOLE_COLOR_DEFAULT: @@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) { } } +void console_init(console_state & con_st) { #if defined (_WIN32) -void win32_console_init(bool enable_color) { + // Windows-specific console initialization unsigned long dwMode = 0; void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { @@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) { } if (hConOut) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { + if (con_st.use_color && !(dwMode & 0x4)) { SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) } // Set console output codepage to UTF8 @@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) { if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); +#endif } +void console_cleanup(console_state & con_st) { +#if !defined(_WIN32) + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif + + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); +} + +// Helper function to remove the last UTF-8 character from a string +void remove_last_utf8_char(std::string & line) { + if (line.empty()) return; + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + } + line.erase(pos); +} + +#if defined (_WIN32) // Convert a wide Unicode string to an UTF8 string void win32_utf8_encode(const std::wstring & wstr, std::string & str) { int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); @@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) { str = strTo; } #endif + +bool console_readline(console_state & con_st, std::string & line) { + line.clear(); + bool is_special_char = false; + bool end_of_stream = false; + + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + + CONSOLE_CHAR_TYPE input_char; + while (true) { + fflush(stdout); // Ensure all output is displayed before waiting for input + input_char = CONSOLE_GET_CHAR(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + putchar('\b'); + putchar(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR(); + if (code == '[') { + // Discard the rest of the escape sequence + while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!line.empty()) { + fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + remove_last_utf8_char(line); + } + } else if (input_char < 32) { + // Ignore control characters + } else { +#if defined(_WIN32) + std::string utf8_char; + win32_utf8_encode(std::wstring(1, input_char), utf8_char); + line += utf8_char; + fputs(utf8_char.c_str(), stdout); +#else + line += input_char; + putchar(input_char); +#endif + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + putchar('\b'); + putchar(line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.author_mode; + if (is_special_char) { + fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + putchar('\n'); + has_more = !has_more; + } else { + // llama doesn't seem to process a single space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + putchar('\b'); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + putchar('\n'); + } + } + + fflush(stdout); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516f..cb1e384e2 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,10 @@ #include #include +#if !defined (_WIN32) +#include +#endif + // // CLI argument parsing // @@ -56,6 +60,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool author_mode = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +109,15 @@ enum console_color_t { }; struct console_state { + bool author_mode = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; +#if !defined (_WIN32) + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 43dca8eb5..5124b8aa9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.author_mode = params.author_mode; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.author_mode) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; }