Add author mode and other related QOL improvements

This commit is contained in:
Danny Daemonic 2023-04-18 02:55:40 -07:00
parent a90e96b266
commit 52e319050b
3 changed files with 184 additions and 48 deletions

View file

@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
const wchar_t * lpWideCharStr, int cchWideChar, const wchar_t * lpWideCharStr, int cchWideChar,
char * lpMultiByteStr, int cbMultiByte, char * lpMultiByteStr, int cbMultiByte,
const char * lpDefaultChar, bool * lpUsedDefaultChar); const char * lpDefaultChar, bool * lpUsedDefaultChar);
#define ENABLE_LINE_INPUT 0x0002
#define ENABLE_ECHO_INPUT 0x0004
#define CP_UTF8 65001 #define CP_UTF8 65001
#define CONSOLE_CHAR_TYPE wchar_t
#define CONSOLE_GET_CHAR() getwchar()
#define CONSOLE_EOF WEOF
#else
#include <unistd.h>
#define CONSOLE_CHAR_TYPE char
#define CONSOLE_GET_CHAR() getchar()
#define CONSOLE_EOF EOF
#endif #endif
int32_t get_num_physical_cores() { int32_t get_num_physical_cores() {
@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.embedding = true; params.embedding = true;
} else if (arg == "--interactive-first") { } else if (arg == "--interactive-first") {
params.interactive_first = true; params.interactive_first = true;
} else if (arg == "--author-mode") {
params.author_mode = true;
} else if (arg == "-ins" || arg == "--instruct") { } else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true; params.instruct = true;
} else if (arg == "--color") { } else if (arg == "--color") {
@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " -i, --interactive run in interactive mode\n");
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " specified more than once for multiple prompts).\n");
@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
} }
/* Keep track of current color of output, and emit ANSI code if it changes. */ /* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) { void console_set_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) { if (con_st.use_color && con_st.color != color) {
switch(color) { switch(color) {
case CONSOLE_COLOR_DEFAULT: case CONSOLE_COLOR_DEFAULT:
@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
} }
} }
void console_init(console_state & con_st) {
#if defined (_WIN32) #if defined (_WIN32)
void win32_console_init(bool enable_color) { // Windows-specific console initialization
unsigned long dwMode = 0; unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) {
} }
if (hConOut) { if (hConOut) {
// Enable ANSI colors on Windows 10+ // Enable ANSI colors on Windows 10+
if (enable_color && !(dwMode & 0x4)) { if (con_st.use_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
} }
// Set console output codepage to UTF8 // Set console output codepage to UTF8
@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) {
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF16 // Set console input codepage to UTF16
_setmode(_fileno(stdin), _O_WTEXT); _setmode(_fileno(stdin), _O_WTEXT);
// Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
SetConsoleMode(hConIn, dwMode);
} }
#else
// POSIX-specific console initialization
struct termios new_termios;
tcgetattr(STDIN_FILENO, &con_st.prev_state);
new_termios = con_st.prev_state;
new_termios.c_lflag &= ~(ICANON | ECHO);
new_termios.c_cc[VMIN] = 1;
new_termios.c_cc[VTIME] = 0;
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
#endif
} }
// Undo what the console setup changed: on non-Windows builds put back the
// terminal attributes stored in con_st.prev_state, then switch the output
// color back to the default.
void console_cleanup(console_state & con_st) {
#ifndef _WIN32
    // POSIX only: re-apply the saved termios state to stdin immediately
    tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif

    // Leave the terminal in its default color
    console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// Drop the final UTF-8 encoded character from `line` (no-op on an empty string).
// Starting at the last byte, step backwards over continuation bytes (10xxxxxx)
// until a non-continuation byte is found, stepping back at most three times
// (a UTF-8 character is at most four bytes) and never before index 0, then
// erase from that position to the end of the string.
void remove_last_utf8_char(std::string & line) {
    if (line.empty()) {
        return;
    }

    size_t start = line.size() - 1;
    int steps_left = 3;
    while (steps_left > 0 && start > 0) {
        const bool is_continuation = (line[start] & 0xC0) == 0x80;
        if (!is_continuation) {
            break; // reached the first byte of the character
        }
        --start;
        --steps_left;
    }

    line.erase(start);
}
#if defined (_WIN32)
// Convert a wide Unicode string to an UTF8 string // Convert a wide Unicode string to an UTF8 string
void win32_utf8_encode(const std::wstring & wstr, std::string & str) { void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
str = strTo; str = strTo;
} }
#endif #endif
// Read one line of user input character by character, echoing and editing it
// by hand (this expects the console to deliver unbuffered, un-echoed input, as
// console_init configures it).
//
//   con_st - console state; supplies the color settings and `author_mode`
//   line   - cleared on entry, receives the UTF-8 text that was typed
//
// Returns true when the caller should read another line, false when the input
// is complete:
//   - plain Return            : '\n' is appended, result is con_st.author_mode
//   - trailing '\' then Return: the '\' is replaced by '\n', result is the
//                               opposite of con_st.author_mode
//   - trailing '/' then Return: the '/' is removed, no newline is added,
//                               result is false
//   - EOF or Ctrl+D           : nothing is appended, result is false (unless a
//                               trailing '\' or '/' is pending, which is then
//                               handled as above)
bool console_readline(console_state & con_st, std::string & line) {
    line.clear();
    bool is_special_char = false;
    bool end_of_stream = false;

    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);

    while (true) {
        fflush(stdout); // Ensure all output is displayed before waiting for input

        // Keep the reader's native return type (int for getchar, wint_t for
        // getwchar). Narrowing to `char` first would make UTF-8 bytes >= 0x80
        // negative where char is signed (so they would be discarded by the
        // `< 32` control-character test below) and would make EOF undetectable
        // where char is unsigned.
        const auto input_char = CONSOLE_GET_CHAR();

        if (input_char == '\r' || input_char == '\n') {
            break;
        }

        if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) {
            end_of_stream = true;
            break;
        }

        if (is_special_char) {
            // More input followed the trailing '\' or '/', so it is ordinary
            // text after all: redraw it in the user-input color.
            console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
            putchar('\b');
            putchar(line.back());
            is_special_char = false;
        }

        if (input_char == '\033') { // Escape sequence
            auto code = CONSOLE_GET_CHAR();
            if (code == '[') {
                // Discard the rest of the escape sequence: it ends at the first
                // letter or '~' (or at end of input).
                while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) {
                    if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
                        break;
                    }
                }
            }
        } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
            if (!line.empty()) {
                fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
                remove_last_utf8_char(line);
            }
        } else if (input_char < 32) {
            // Ignore control characters
        } else {
#if defined(_WIN32)
            std::string utf8_char;
            win32_utf8_encode(std::wstring(1, static_cast<wchar_t>(input_char)), utf8_char);
            line += utf8_char;
            fputs(utf8_char.c_str(), stdout);
#else
            line += static_cast<char>(input_char);
            putchar(input_char);
#endif
        }

        if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
            // A trailing '\' or '/' may turn out to be a control character:
            // redraw it in the prompt color until we know what follows.
            console_set_color(con_st, CONSOLE_COLOR_PROMPT);
            putchar('\b');
            putchar(line.back());
            is_special_char = true;
        }
    }

    bool has_more = con_st.author_mode;
    if (is_special_char) {
        fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again

        char last = line.back();
        line.pop_back();
        if (last == '\\') {
            // '\' inverts the mode's default meaning of Return
            line += '\n';
            putchar('\n');
            has_more = !has_more;
        } else {
            // llama doesn't seem to process a single space
            if (line.length() == 1 && line.back() == ' ') {
                line.clear();
                putchar('\b');
            }
            has_more = false;
        }
    } else {
        if (end_of_stream) {
            has_more = false;
        } else {
            line += '\n';
            putchar('\n');
        }
    }

    fflush(stdout);
    return has_more;
}

View file

@ -10,6 +10,10 @@
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#if !defined (_WIN32)
#include <termios.h>
#endif
// //
// CLI argument parsing // CLI argument parsing
// //
@ -56,6 +60,7 @@ struct gpt_params {
bool embedding = false; // get only sentence embedding bool embedding = false; // get only sentence embedding
bool interactive_first = false; // wait for user input immediately bool interactive_first = false; // wait for user input immediately
bool author_mode = false; // reverse the usage of `\`
bool instruct = false; // instruction mode (used for Alpaca models) bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token bool penalize_nl = true; // consider newlines as a repeatable token
@ -104,13 +109,15 @@ enum console_color_t {
}; };
struct console_state { struct console_state {
bool author_mode = false;
bool use_color = false; bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT; console_color_t color = CONSOLE_COLOR_DEFAULT;
#if !defined (_WIN32)
termios prev_state;
#endif
}; };
void set_console_color(console_state & con_st, console_color_t color); void console_init(console_state & con_st);
void console_cleanup(console_state & con_st);
#if defined (_WIN32) void console_set_color(console_state & con_st, console_color_t color);
void win32_console_init(bool enable_color); bool console_readline(console_state & con_st, std::string & line);
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
#endif

View file

@ -35,12 +35,12 @@ static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) { void sigint_handler(int signo) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) { if (signo == SIGINT) {
if (!is_interacting) { if (!is_interacting) {
is_interacting=true; is_interacting=true;
} else { } else {
console_cleanup(con_st);
printf("\n");
llama_print_timings(*g_ctx); llama_print_timings(*g_ctx);
_exit(130); _exit(130);
} }
@ -59,10 +59,9 @@ int main(int argc, char ** argv) {
// save choice to use color for later // save choice to use color for later
// (note for later: this is a slightly awkward choice) // (note for later: this is a slightly awkward choice)
con_st.use_color = params.use_color; con_st.use_color = params.use_color;
con_st.author_mode = params.author_mode;
#if defined (_WIN32) console_init(con_st);
win32_console_init(params.use_color); atexit([]() { console_cleanup(con_st); });
#endif
if (params.perplexity) { if (params.perplexity) {
printf("\n************\n"); printf("\n************\n");
@ -275,12 +274,21 @@ int main(int argc, char ** argv) {
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.interactive) { if (params.interactive) {
const char *control_message;
if (con_st.author_mode) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
control_message = " - Press Return to return control to LLaMa.\n"
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
fprintf(stderr, "== Running in interactive mode. ==\n" fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
" - Press Ctrl+C to interject at any time.\n" " - Press Ctrl+C to interject at any time.\n"
#endif #endif
" - Press Return to return control to LLaMa.\n" "%s\n", control_message);
" - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_first; is_interacting = params.interactive_first;
} }
@ -299,7 +307,7 @@ int main(int argc, char ** argv) {
int n_session_consumed = 0; int n_session_consumed = 0;
// the first thing we will do is to output the prompt, so set color accordingly // the first thing we will do is to output the prompt, so set color accordingly
set_console_color(con_st, CONSOLE_COLOR_PROMPT); console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd; std::vector<llama_token> embd;
@ -498,7 +506,7 @@ int main(int argc, char ** argv) {
} }
// reset color to default if we there is no pending user input // reset color to default if we there is no pending user input
if (input_echo && (int)embd_inp.size() == n_consumed) { if (input_echo && (int)embd_inp.size() == n_consumed) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT); console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
} }
// in interactive mode, and not currently processing queued inputs; // in interactive mode, and not currently processing queued inputs;
@ -518,17 +526,12 @@ int main(int argc, char ** argv) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true; is_interacting = true;
is_antiprompt = true; is_antiprompt = true;
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
fflush(stdout);
break; break;
} }
} }
} }
if (n_past > 0 && is_interacting) { if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (params.instruct) { if (params.instruct) {
printf("\n> "); printf("\n> ");
} }
@ -542,31 +545,12 @@ int main(int argc, char ** argv) {
std::string line; std::string line;
bool another_line = true; bool another_line = true;
do { do {
#if defined(_WIN32) another_line = console_readline(con_st, line);
std::wstring wline; buffer += line;
if (!std::getline(std::wcin, wline)) {
// input stream is bad or EOF received
return 0;
}
win32_utf8_encode(wline, line);
#else
if (!std::getline(std::cin, line)) {
// input stream is bad or EOF received
return 0;
}
#endif
if (!line.empty()) {
if (line.back() == '\\') {
line.pop_back(); // Remove the continue character
} else {
another_line = false;
}
buffer += line + '\n'; // Append the line to the result
}
} while (another_line); } while (another_line);
// done taking input, reset color // done taking input, reset color
set_console_color(con_st, CONSOLE_COLOR_DEFAULT); console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
// Add tokens to embd only if the input buffer is non-empty // Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back // Entering a empty line lets the user pass control back
@ -622,7 +606,5 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_free(ctx); llama_free(ctx);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 0; return 0;
} }