Works with all characters and control codes + Windows console fixes

This commit is contained in:
Danny Daemonic 2023-05-07 02:39:10 -07:00
parent 534c89e766
commit 5cc9085353
2 changed files with 168 additions and 80 deletions

View file

@ -14,14 +14,16 @@
#include <sys/sysctl.h> #include <sys/sysctl.h>
#endif #endif
#if defined (_WIN32) #if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
#define NOMINMAX #define NOMINMAX
#include <windows.h> #include <windows.h>
#include <fcntl.h> #include <fcntl.h>
#include <io.h> #include <io.h>
#else #else
#include <sys/ioctl.h>
#include <unistd.h> #include <unistd.h>
#include <wchar.h>
#endif #endif
int32_t get_num_physical_cores() { int32_t get_num_physical_cores() {
@ -473,45 +475,27 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
return lctx; return lctx;
} }
/* Keep track of current color of output, and emit ANSI code if it changes. */
void console_set_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
switch(color) {
case CONSOLE_COLOR_DEFAULT:
printf(ANSI_COLOR_RESET);
break;
case CONSOLE_COLOR_PROMPT:
printf(ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
}
}
void console_init(console_state & con_st) { void console_init(console_state & con_st) {
#if defined (_WIN32) #if defined(_WIN32)
// Windows-specific console initialization // Windows-specific console initialization
unsigned long dwMode = 0; DWORD dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE);
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) {
hConOut = 0; con_st.hConsole = NULL;
} }
} }
if (hConOut) { if (con_st.hConsole) {
// Enable ANSI colors on Windows 10+ // Enable ANSI colors on Windows 10+
if (con_st.use_color && !(dwMode & 0x4)) { if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
} }
// Set console output codepage to UTF8 // Set console output codepage to UTF8
SetConsoleOutputCP(CP_UTF8); SetConsoleOutputCP(CP_UTF8);
} }
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF16 // Set console input codepage to UTF16
_setmode(_fileno(stdin), _O_WTEXT); _setmode(_fileno(stdin), _O_WTEXT);
@ -528,46 +512,49 @@ void console_init(console_state & con_st) {
new_termios.c_cc[VMIN] = 1; new_termios.c_cc[VMIN] = 1;
new_termios.c_cc[VTIME] = 0; new_termios.c_cc[VTIME] = 0;
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
con_st.tty = fopen("/dev/tty", "w+");
if (con_st.tty != nullptr) {
con_st.out = con_st.tty;
}
#endif #endif
setlocale(LC_ALL, ""); setlocale(LC_ALL, "");
} }
void console_cleanup(console_state & con_st) { void console_cleanup(console_state & con_st) {
// Reset console color
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
#if !defined(_WIN32) #if !defined(_WIN32)
if (con_st.tty != nullptr) {
con_st.out = stdout;
fclose(con_st.tty);
con_st.tty = nullptr;
}
// Restore the terminal settings on POSIX systems // Restore the terminal settings on POSIX systems
tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif #endif
// Reset console color
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
} }
#if defined (_WIN32) /* Keep track of current color of output, and emit ANSI code if it changes. */
int puts_get_width(_In_z_ CONST CHAR* lpBuffer) { void console_set_color(console_state & con_st, console_color_t color) {
DWORD nNumberOfCharsToWrite = strlen(lpBuffer); if (con_st.use_color && con_st.color != color) {
fflush(stdout);
HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); switch(color) {
CONSOLE_SCREEN_BUFFER_INFO bufferInfo; case CONSOLE_COLOR_DEFAULT:
if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { fprintf(con_st.out, ANSI_COLOR_RESET);
// Make a guess break;
return 1; case CONSOLE_COLOR_PROMPT:
fprintf(con_st.out, ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
fflush(con_st.out);
} }
COORD initialPosition = bufferInfo.dwCursorPosition;
DWORD written = 0;
WriteConsole(hConsole, lpBuffer, nNumberOfCharsToWrite, &written, nullptr);
CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
if (newBufferInfo.dwCursorPosition.Y > initialPosition.Y) {
width += (newBufferInfo.dwSize.X - initialPosition.X);
}
return width;
} }
#endif
char32_t getchar32() { char32_t getchar32() {
wchar_t wc = getwchar(); wchar_t wc = getwchar();
@ -590,6 +577,102 @@ char32_t getchar32() {
return static_cast<char32_t>(wc); return static_cast<char32_t>(wc);
} }
void pop_cursor(console_state & con_st) {
#if defined(_WIN32)
if (con_st.hConsole != NULL) {
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo);
COORD newCursorPosition = bufferInfo.dwCursorPosition;
if (newCursorPosition.X == 0) {
newCursorPosition.X = bufferInfo.dwSize.X - 1;
newCursorPosition.Y -= 1;
} else {
newCursorPosition.X -= 1;
}
SetConsoleCursorPosition(con_st.hConsole, newCursorPosition);
return;
}
#endif
putc('\b', con_st.out);
}
int estimateWidth(char32_t codepoint) {
#if defined(_WIN32)
return 1;
#else
return wcwidth(codepoint);
#endif
}
int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) {
#if defined(_WIN32)
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) {
// go with the default
return expectedWidth;
}
COORD initialPosition = bufferInfo.dwCursorPosition;
DWORD nNumberOfChars = length;
WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
// Figure out our real position if we're in the last column
if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
DWORD nNumberOfChars;
WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL);
GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
}
int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
if (width < 0) {
width += newBufferInfo.dwSize.X;
}
return width;
#else
// we can trust expectedWidth if we've got one
if (expectedWidth >= 0 || con_st.tty == nullptr) {
fwrite(utf8_codepoint, length, 1, con_st.out);
return expectedWidth;
}
fputs("\033[6n", con_st.tty); // Query cursor position
int x1, x2, y1, y2;
int results = 0;
results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1);
fwrite(utf8_codepoint, length, 1, con_st.tty);
fputs("\033[6n", con_st.tty); // Query cursor position
results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2);
if (results != 4) {
return expectedWidth;
}
int width = x2 - x1;
if (width < 0) {
// Calculate the width considering text wrapping
struct winsize w;
ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
width += w.ws_col;
}
return width;
#endif
}
void replace_last(console_state & con_st, char ch) {
#if defined(_WIN32)
pop_cursor(con_st);
put_codepoint(con_st, &ch, 1, 1);
#else
fprintf(con_st.out, "\b%c", ch);
#endif
}
void append_utf8(char32_t ch, std::string & out) { void append_utf8(char32_t ch, std::string & out) {
if (ch <= 0x7F) { if (ch <= 0x7F) {
out.push_back(static_cast<unsigned char>(ch)); out.push_back(static_cast<unsigned char>(ch));
@ -627,6 +710,9 @@ void pop_back_utf8_char(std::string & line) {
bool console_readline(console_state & con_st, std::string & line) { bool console_readline(console_state & con_st, std::string & line) {
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (con_st.out != stdout) {
fflush(stdout);
}
line.clear(); line.clear();
std::vector<int> widths; std::vector<int> widths;
@ -635,7 +721,7 @@ bool console_readline(console_state & con_st, std::string & line) {
char32_t input_char; char32_t input_char;
while (true) { while (true) {
fflush(stdout); // Ensure all output is displayed before waiting for input fflush(con_st.out); // Ensure all output is displayed before waiting for input
input_char = getchar32(); input_char = getchar32();
if (input_char == '\r' || input_char == '\n') { if (input_char == '\r' || input_char == '\n') {
@ -649,8 +735,7 @@ bool console_readline(console_state & con_st, std::string & line) {
if (is_special_char) { if (is_special_char) {
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
putchar('\b'); replace_last(con_st, line.back());
putchar(line.back());
is_special_char = false; is_special_char = false;
} }
@ -670,50 +755,47 @@ bool console_readline(console_state & con_st, std::string & line) {
do { do {
count = widths.back(); count = widths.back();
widths.pop_back(); widths.pop_back();
// Move cursor back, print spaces, and move cursor back again // Move cursor back, print space, and move cursor back again
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
fputs("\b \b", stdout); replace_last(con_st, ' ');
pop_cursor(con_st);
} }
pop_back_utf8_char(line); pop_back_utf8_char(line);
} while (count == 0 && !widths.empty()); } while (count == 0 && !widths.empty());
} }
} else if (input_char < 32) {
// Ignore control characters
} else { } else {
int offset = line.length(); int offset = line.length();
append_utf8(input_char, line); append_utf8(input_char, line);
#if defined (_WIN32) int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
int width = puts_get_width(line.c_str() + offset); if (width < 0) {
width = 0;
}
widths.push_back(width); widths.push_back(width);
#else
fputs(line.c_str() + offset, stdout);
widths.push_back(wcwidth(input_char));
#endif
} }
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT); console_set_color(con_st, CONSOLE_COLOR_PROMPT);
putchar('\b'); replace_last(con_st, line.back());
putchar(line.back());
is_special_char = true; is_special_char = true;
} }
} }
bool has_more = con_st.multiline_input; bool has_more = con_st.multiline_input;
if (is_special_char) { if (is_special_char) {
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again replace_last(con_st, ' ');
pop_cursor(con_st);
char last = line.back(); char last = line.back();
line.pop_back(); line.pop_back();
if (last == '\\') { if (last == '\\') {
line += '\n'; line += '\n';
putchar('\n'); fputc('\n', con_st.out);
has_more = !has_more; has_more = !has_more;
} else { } else {
// llama will just eat the single space // llama will just eat the single space, it won't act as a space
if (line.length() == 1 && line.back() == ' ') { if (line.length() == 1 && line.back() == ' ') {
line.clear(); line.clear();
putchar('\b'); pop_cursor(con_st);
} }
has_more = false; has_more = false;
} }
@ -722,10 +804,10 @@ bool console_readline(console_state & con_st, std::string & line) {
has_more = false; has_more = false;
} else { } else {
line += '\n'; line += '\n';
putchar('\n'); fputc('\n', con_st.out);
} }
} }
fflush(stdout); fflush(con_st.out);
return has_more; return has_more;
} }

View file

@ -11,6 +11,7 @@
#include <unordered_map> #include <unordered_map>
#if !defined (_WIN32) #if !defined (_WIN32)
#include <stdio.h>
#include <termios.h> #include <termios.h>
#endif #endif
@ -112,7 +113,12 @@ struct console_state {
bool multiline_input = false; bool multiline_input = false;
bool use_color = false; bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT; console_color_t color = CONSOLE_COLOR_DEFAULT;
#if !defined (_WIN32)
FILE* out = stdout;
#if defined (_WIN32)
void* hConsole;
#else
FILE* tty = nullptr;
termios prev_state; termios prev_state;
#endif #endif
}; };