From 5cc9085353b788c04ffd84860dd6b93bc54393c0 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sun, 7 May 2023 02:39:10 -0700 Subject: [PATCH] Works with all characters and control codes + Windows console fixes --- examples/common.cpp | 240 +++++++++++++++++++++++++++++--------------- examples/common.h | 8 +- 2 files changed, 168 insertions(+), 80 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index f6b2d6b13..5eeab0cb1 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,14 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) #define WIN32_LEAN_AND_MEAN #define NOMINMAX #include #include #include #else +#include #include +#include #endif int32_t get_num_physical_cores() { @@ -473,45 +475,27 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void console_set_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - void console_init(console_state & con_st) { -#if defined (_WIN32) +#if defined(_WIN32) // Windows-specific console initialization - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (con_st.use_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); @@ -528,46 +512,49 @@ void console_init(console_state & con_st) { new_termios.c_cc[VMIN] = 1; new_termios.c_cc[VTIME] = 0; tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } #endif setlocale(LC_ALL, ""); } void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + #if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } // Restore the terminal settings on POSIX systems tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); #endif - - // Reset console color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } -#if defined (_WIN32) -int puts_get_width(_In_z_ CONST CHAR* lpBuffer) { - DWORD nNumberOfCharsToWrite = strlen(lpBuffer); - - HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { - // Make a guess - return 1; +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } - COORD initialPosition = bufferInfo.dwCursorPosition; - - DWORD written = 0; - WriteConsole(hConsole, lpBuffer, nNumberOfCharsToWrite, &written, nullptr); - - CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; - GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); - - int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; - if (newBufferInfo.dwCursorPosition.Y > initialPosition.Y) { - width += (newBufferInfo.dwSize.X - initialPosition.X); - } - - return width; } -#endif char32_t getchar32() { wchar_t wc = getwchar(); @@ -590,6 +577,102 @@ char32_t getchar32() { return static_cast(wc); } +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + void append_utf8(char32_t ch, std::string & out) { if (ch <= 0x7F) { out.push_back(static_cast(ch)); @@ -627,6 +710,9 @@ void pop_back_utf8_char(std::string & line) { bool console_readline(console_state & con_st, std::string & line) { console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } line.clear(); std::vector widths; @@ -635,7 +721,7 @@ bool console_readline(console_state & con_st, std::string & line) { char32_t input_char; while (true) { - fflush(stdout); // Ensure all output is displayed before waiting for input + fflush(con_st.out); // Ensure all output is displayed before waiting for input input_char = getchar32(); if (input_char == '\r' || input_char == '\n') { @@ -649,8 +735,7 @@ bool console_readline(console_state & con_st, std::string & line) { if (is_special_char) { console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - putchar('\b'); - putchar(line.back()); + replace_last(con_st, line.back()); is_special_char = false; } @@ -670,50 +755,47 @@ bool console_readline(console_state & con_st, std::string & line) { do { count = widths.back(); widths.pop_back(); - // Move cursor back, print spaces, and move cursor back again + // Move cursor back, print space, and move cursor back again for (int i = 0; i < count; i++) { - fputs("\b \b", stdout); + replace_last(con_st, ' '); + pop_cursor(con_st); } pop_back_utf8_char(line); } while (count == 0 && !widths.empty()); } - } else if (input_char < 32) { - // Ignore control characters } else { int offset = line.length(); append_utf8(input_char, line); -#if defined (_WIN32) - int width = puts_get_width(line.c_str() + offset); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } widths.push_back(width); -#else - fputs(line.c_str() + offset, stdout); - widths.push_back(wcwidth(input_char)); -#endif } if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { console_set_color(con_st, CONSOLE_COLOR_PROMPT); - putchar('\b'); - putchar(line.back()); + replace_last(con_st, line.back()); is_special_char = true; } } bool has_more = con_st.multiline_input; if (is_special_char) { - fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + replace_last(con_st, ' '); + pop_cursor(con_st); char last = line.back(); line.pop_back(); if (last == '\\') { line += '\n'; - putchar('\n'); + fputc('\n', con_st.out); has_more = !has_more; } else { - // llama will just eat the single space + // llama will just eat the single space, it won't act as a space if (line.length() == 1 && line.back() == ' ') { line.clear(); - putchar('\b'); + pop_cursor(con_st); } has_more = false; } @@ -722,10 +804,10 @@ bool console_readline(console_state & con_st, std::string & line) { has_more = false; } else { line += '\n'; - putchar('\n'); + fputc('\n', con_st.out); } } - fflush(stdout); + fflush(con_st.out); return has_more; } diff --git a/examples/common.h b/examples/common.h index 0950fc7c3..43f1cc9ef 100644 --- a/examples/common.h +++ b/examples/common.h @@ -11,6 +11,7 @@ #include #if !defined (_WIN32) +#include #include #endif @@ -112,7 +113,12 @@ struct console_state { bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; -#if !defined (_WIN32) + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; termios prev_state; #endif };