Add author mode and other related QOL improvements

This commit is contained in:
Danny Daemonic 2023-04-18 02:55:40 -07:00
parent a90e96b266
commit 52e319050b
3 changed files with 184 additions and 48 deletions

View file

@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
const wchar_t * lpWideCharStr, int cchWideChar,
char * lpMultiByteStr, int cbMultiByte,
const char * lpDefaultChar, bool * lpUsedDefaultChar);
#define ENABLE_LINE_INPUT 0x0002
#define ENABLE_ECHO_INPUT 0x0004
#define CP_UTF8 65001
#define CONSOLE_CHAR_TYPE wchar_t
#define CONSOLE_GET_CHAR() getwchar()
#define CONSOLE_EOF WEOF
#else
#include <unistd.h>
#define CONSOLE_CHAR_TYPE char
#define CONSOLE_GET_CHAR() getchar()
#define CONSOLE_EOF EOF
#endif
int32_t get_num_physical_cores() {
@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.embedding = true;
} else if (arg == "--interactive-first") {
params.interactive_first = true;
} else if (arg == "--author-mode") {
params.author_mode = true;
} else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true;
} else if (arg == "--color") {
@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -i, --interactive run in interactive mode\n");
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
}
/* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) {
void console_set_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
switch(color) {
case CONSOLE_COLOR_DEFAULT:
@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
}
}
void console_init(console_state & con_st) {
#if defined (_WIN32)
void win32_console_init(bool enable_color) {
// Windows-specific console initialization
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) {
}
if (hConOut) {
// Enable ANSI colors on Windows 10+
if (enable_color && !(dwMode & 0x4)) {
if (con_st.use_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
// Set console output codepage to UTF8
@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) {
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF16
_setmode(_fileno(stdin), _O_WTEXT);
// Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
SetConsoleMode(hConIn, dwMode);
}
#else
// POSIX-specific console initialization
struct termios new_termios;
tcgetattr(STDIN_FILENO, &con_st.prev_state);
new_termios = con_st.prev_state;
new_termios.c_lflag &= ~(ICANON | ECHO);
new_termios.c_cc[VMIN] = 1;
new_termios.c_cc[VTIME] = 0;
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
#endif
}
void console_cleanup(console_state & con_st) {
#if !defined(_WIN32)
// Restore the terminal settings on POSIX systems
tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif
// Reset console color
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// Helper function to remove the last UTF-8 character from a string
void remove_last_utf8_char(std::string & line) {
if (line.empty()) return;
size_t pos = line.length() - 1;
// Find the start of the last UTF-8 character (checking up to 4 bytes back)
for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character
}
line.erase(pos);
}
#if defined (_WIN32)
// Convert a wide Unicode string to an UTF8 string
void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
str = strTo;
}
#endif
bool console_readline(console_state & con_st, std::string & line) {
line.clear();
bool is_special_char = false;
bool end_of_stream = false;
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
CONSOLE_CHAR_TYPE input_char;
while (true) {
fflush(stdout); // Ensure all output is displayed before waiting for input
input_char = CONSOLE_GET_CHAR();
if (input_char == '\r' || input_char == '\n') {
break;
}
if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) {
end_of_stream = true;
break;
}
if (is_special_char) {
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
putchar('\b');
putchar(line.back());
is_special_char = false;
}
if (input_char == '\033') { // Escape sequence
CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR();
if (code == '[') {
// Discard the rest of the escape sequence
while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) {
if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
break;
}
}
}
} else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
if (!line.empty()) {
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
remove_last_utf8_char(line);
}
} else if (input_char < 32) {
// Ignore control characters
} else {
#if defined(_WIN32)
std::string utf8_char;
win32_utf8_encode(std::wstring(1, input_char), utf8_char);
line += utf8_char;
fputs(utf8_char.c_str(), stdout);
#else
line += input_char;
putchar(input_char);
#endif
}
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
putchar('\b');
putchar(line.back());
is_special_char = true;
}
}
bool has_more = con_st.author_mode;
if (is_special_char) {
fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
char last = line.back();
line.pop_back();
if (last == '\\') {
line += '\n';
putchar('\n');
has_more = !has_more;
} else {
// llama doesn't seem to process a single space
if (line.length() == 1 && line.back() == ' ') {
line.clear();
putchar('\b');
}
has_more = false;
}
} else {
if (end_of_stream) {
has_more = false;
} else {
line += '\n';
putchar('\n');
}
}
fflush(stdout);
return has_more;
}

View file

@ -10,6 +10,10 @@
#include <thread>
#include <unordered_map>
#if !defined (_WIN32)
#include <termios.h>
#endif
//
// CLI argument parsing
//
@ -56,6 +60,7 @@ struct gpt_params {
bool embedding = false; // get only sentence embedding
bool interactive_first = false; // wait for user input immediately
bool author_mode = false; // reverse the usage of `\`
bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token
@ -104,13 +109,15 @@ enum console_color_t {
};
struct console_state {
bool author_mode = false;
bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT;
#if !defined (_WIN32)
termios prev_state;
#endif
};
void set_console_color(console_state & con_st, console_color_t color);
#if defined (_WIN32)
void win32_console_init(bool enable_color);
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
#endif
void console_init(console_state & con_st);
void console_cleanup(console_state & con_st);
void console_set_color(console_state & con_st, console_color_t color);
bool console_readline(console_state & con_st, std::string & line);

View file

@ -35,12 +35,12 @@ static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
} else {
console_cleanup(con_st);
printf("\n");
llama_print_timings(*g_ctx);
_exit(130);
}
@ -59,10 +59,9 @@ int main(int argc, char ** argv) {
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
con_st.use_color = params.use_color;
#if defined (_WIN32)
win32_console_init(params.use_color);
#endif
con_st.author_mode = params.author_mode;
console_init(con_st);
atexit([]() { console_cleanup(con_st); });
if (params.perplexity) {
printf("\n************\n");
@ -275,12 +274,21 @@ int main(int argc, char ** argv) {
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.interactive) {
const char *control_message;
if (con_st.author_mode) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
control_message = " - Press Return to return control to LLaMa.\n"
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
" - Press Ctrl+C to interject at any time.\n"
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
"%s\n", control_message);
is_interacting = params.interactive_first;
}
@ -299,7 +307,7 @@ int main(int argc, char ** argv) {
int n_session_consumed = 0;
// the first thing we will do is to output the prompt, so set color accordingly
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
@ -498,7 +506,7 @@ int main(int argc, char ** argv) {
}
// reset color to default if we there is no pending user input
if (input_echo && (int)embd_inp.size() == n_consumed) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// in interactive mode, and not currently processing queued inputs;
@ -518,17 +526,12 @@ int main(int argc, char ** argv) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
is_antiprompt = true;
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
fflush(stdout);
break;
}
}
}
if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (params.instruct) {
printf("\n> ");
}
@ -542,31 +545,12 @@ int main(int argc, char ** argv) {
std::string line;
bool another_line = true;
do {
#if defined(_WIN32)
std::wstring wline;
if (!std::getline(std::wcin, wline)) {
// input stream is bad or EOF received
return 0;
}
win32_utf8_encode(wline, line);
#else
if (!std::getline(std::cin, line)) {
// input stream is bad or EOF received
return 0;
}
#endif
if (!line.empty()) {
if (line.back() == '\\') {
line.pop_back(); // Remove the continue character
} else {
another_line = false;
}
buffer += line + '\n'; // Append the line to the result
}
another_line = console_readline(con_st, line);
buffer += line;
} while (another_line);
// done taking input, reset color
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
// Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back
@ -622,7 +606,5 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 0;
}