Works with all characters and control codes + Windows console fixes

This commit is contained in:
Danny Daemonic 2023-05-07 02:39:10 -07:00
parent 534c89e766
commit 5cc9085353
2 changed files with 168 additions and 80 deletions

View file

@ -14,14 +14,16 @@
#include <sys/sysctl.h>
#endif
#if defined (_WIN32)
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <fcntl.h>
#include <io.h>
#else
#include <sys/ioctl.h>
#include <unistd.h>
#include <wchar.h>
#endif
int32_t get_num_physical_cores() {
@ -473,45 +475,27 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
return lctx;
}
/* Keep track of current color of output, and emit ANSI code if it changes. */
void console_set_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
switch(color) {
case CONSOLE_COLOR_DEFAULT:
printf(ANSI_COLOR_RESET);
break;
case CONSOLE_COLOR_PROMPT:
printf(ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
}
}
void console_init(console_state & con_st) {
#if defined (_WIN32)
#if defined(_WIN32)
// Windows-specific console initialization
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
hConOut = 0;
DWORD dwMode = 0;
con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) {
con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE);
if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) {
con_st.hConsole = NULL;
}
}
if (hConOut) {
if (con_st.hConsole) {
// Enable ANSI colors on Windows 10+
if (con_st.use_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
}
// Set console output codepage to UTF8
SetConsoleOutputCP(CP_UTF8);
}
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF16
_setmode(_fileno(stdin), _O_WTEXT);
@ -528,46 +512,49 @@ void console_init(console_state & con_st) {
new_termios.c_cc[VMIN] = 1;
new_termios.c_cc[VTIME] = 0;
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
con_st.tty = fopen("/dev/tty", "w+");
if (con_st.tty != nullptr) {
con_st.out = con_st.tty;
}
#endif
setlocale(LC_ALL, "");
}
void console_cleanup(console_state & con_st) {
// Reset console color
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
#if !defined(_WIN32)
if (con_st.tty != nullptr) {
con_st.out = stdout;
fclose(con_st.tty);
con_st.tty = nullptr;
}
// Restore the terminal settings on POSIX systems
tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif
// Reset console color
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
#if defined (_WIN32)
int puts_get_width(_In_z_ CONST CHAR* lpBuffer) {
DWORD nNumberOfCharsToWrite = strlen(lpBuffer);
HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
// Make a guess
return 1;
/* Keep track of current color of output, and emit ANSI code if it changes. */
void console_set_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
fflush(stdout);
switch(color) {
case CONSOLE_COLOR_DEFAULT:
fprintf(con_st.out, ANSI_COLOR_RESET);
break;
case CONSOLE_COLOR_PROMPT:
fprintf(con_st.out, ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
fflush(con_st.out);
}
COORD initialPosition = bufferInfo.dwCursorPosition;
DWORD written = 0;
WriteConsole(hConsole, lpBuffer, nNumberOfCharsToWrite, &written, nullptr);
CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
if (newBufferInfo.dwCursorPosition.Y > initialPosition.Y) {
width += (newBufferInfo.dwSize.X - initialPosition.X);
}
return width;
}
#endif
char32_t getchar32() {
wchar_t wc = getwchar();
@ -590,6 +577,102 @@ char32_t getchar32() {
return static_cast<char32_t>(wc);
}
void pop_cursor(console_state & con_st) {
#if defined(_WIN32)
if (con_st.hConsole != NULL) {
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo);
COORD newCursorPosition = bufferInfo.dwCursorPosition;
if (newCursorPosition.X == 0) {
newCursorPosition.X = bufferInfo.dwSize.X - 1;
newCursorPosition.Y -= 1;
} else {
newCursorPosition.X -= 1;
}
SetConsoleCursorPosition(con_st.hConsole, newCursorPosition);
return;
}
#endif
putc('\b', con_st.out);
}
int estimateWidth(char32_t codepoint) {
#if defined(_WIN32)
return 1;
#else
return wcwidth(codepoint);
#endif
}
int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) {
#if defined(_WIN32)
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) {
// go with the default
return expectedWidth;
}
COORD initialPosition = bufferInfo.dwCursorPosition;
DWORD nNumberOfChars = length;
WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
// Figure out our real position if we're in the last column
if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
DWORD nNumberOfChars;
WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL);
GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo);
}
int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
if (width < 0) {
width += newBufferInfo.dwSize.X;
}
return width;
#else
// we can trust expectedWidth if we've got one
if (expectedWidth >= 0 || con_st.tty == nullptr) {
fwrite(utf8_codepoint, length, 1, con_st.out);
return expectedWidth;
}
fputs("\033[6n", con_st.tty); // Query cursor position
int x1, x2, y1, y2;
int results = 0;
results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1);
fwrite(utf8_codepoint, length, 1, con_st.tty);
fputs("\033[6n", con_st.tty); // Query cursor position
results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2);
if (results != 4) {
return expectedWidth;
}
int width = x2 - x1;
if (width < 0) {
// Calculate the width considering text wrapping
struct winsize w;
ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
width += w.ws_col;
}
return width;
#endif
}
void replace_last(console_state & con_st, char ch) {
#if defined(_WIN32)
pop_cursor(con_st);
put_codepoint(con_st, &ch, 1, 1);
#else
fprintf(con_st.out, "\b%c", ch);
#endif
}
void append_utf8(char32_t ch, std::string & out) {
if (ch <= 0x7F) {
out.push_back(static_cast<unsigned char>(ch));
@ -627,6 +710,9 @@ void pop_back_utf8_char(std::string & line) {
bool console_readline(console_state & con_st, std::string & line) {
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (con_st.out != stdout) {
fflush(stdout);
}
line.clear();
std::vector<int> widths;
@ -635,7 +721,7 @@ bool console_readline(console_state & con_st, std::string & line) {
char32_t input_char;
while (true) {
fflush(stdout); // Ensure all output is displayed before waiting for input
fflush(con_st.out); // Ensure all output is displayed before waiting for input
input_char = getchar32();
if (input_char == '\r' || input_char == '\n') {
@ -649,8 +735,7 @@ bool console_readline(console_state & con_st, std::string & line) {
if (is_special_char) {
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
putchar('\b');
putchar(line.back());
replace_last(con_st, line.back());
is_special_char = false;
}
@ -670,50 +755,47 @@ bool console_readline(console_state & con_st, std::string & line) {
do {
count = widths.back();
widths.pop_back();
// Move cursor back, print spaces, and move cursor back again
// Move cursor back, print space, and move cursor back again
for (int i = 0; i < count; i++) {
fputs("\b \b", stdout);
replace_last(con_st, ' ');
pop_cursor(con_st);
}
pop_back_utf8_char(line);
} while (count == 0 && !widths.empty());
}
} else if (input_char < 32) {
// Ignore control characters
} else {
int offset = line.length();
append_utf8(input_char, line);
#if defined (_WIN32)
int width = puts_get_width(line.c_str() + offset);
int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
if (width < 0) {
width = 0;
}
widths.push_back(width);
#else
fputs(line.c_str() + offset, stdout);
widths.push_back(wcwidth(input_char));
#endif
}
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
putchar('\b');
putchar(line.back());
replace_last(con_st, line.back());
is_special_char = true;
}
}
bool has_more = con_st.multiline_input;
if (is_special_char) {
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
replace_last(con_st, ' ');
pop_cursor(con_st);
char last = line.back();
line.pop_back();
if (last == '\\') {
line += '\n';
putchar('\n');
fputc('\n', con_st.out);
has_more = !has_more;
} else {
// llama will just eat the single space
// llama will just eat the single space, it won't act as a space
if (line.length() == 1 && line.back() == ' ') {
line.clear();
putchar('\b');
pop_cursor(con_st);
}
has_more = false;
}
@ -722,10 +804,10 @@ bool console_readline(console_state & con_st, std::string & line) {
has_more = false;
} else {
line += '\n';
putchar('\n');
fputc('\n', con_st.out);
}
}
fflush(stdout);
fflush(con_st.out);
return has_more;
}

View file

@ -11,6 +11,7 @@
#include <unordered_map>
#if !defined (_WIN32)
#include <stdio.h>
#include <termios.h>
#endif
@ -112,7 +113,12 @@ struct console_state {
bool multiline_input = false;
bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT;
#if !defined (_WIN32)
FILE* out = stdout;
#if defined (_WIN32)
void* hConsole;
#else
FILE* tty = nullptr;
termios prev_state;
#endif
};