diff --git a/.gitignore b/.gitignore index 21c575989..55bd30f9c 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ zig-out/ zig-cache/ ppl-*.txt +qnt-*.txt examples/jeopardy/results.txt koboldcpp.so diff --git a/convert.py b/convert.py index 126beaabc..8f4f0399e 100644 --- a/convert.py +++ b/convert.py @@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) description = f'safetensors begin={begin} end={end} type={data_type} path={path}' return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items()} + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) @@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] - files = [file for glob in globs for file in path.glob(glob)] + # Check if it's a set of safetensors files first + files = list(path.glob("model-00001-of-*.safetensors")) + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] + files = [file for glob in globs for file in path.glob(glob)] if not files: # Try GGML too, but with lower priority, since if both a non-GGML # model and a GGML model exist in the same directory, we assume the diff --git a/examples/common.cpp b/examples/common.cpp index f1c3bae13..23d69e7d5 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,20 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); -#define CP_UTF8 65001 +#else +#include +#include +#include #endif int32_t get_num_physical_cores() { @@ -269,6 +265,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -359,6 +357,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -438,8 +437,8 @@ std::string gpt_random_prompt(std::mt19937 & rng) { // TODO: not great allocating this every time std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars - std::vector res(text.size() + (int)add_bos); - int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); + std::vector res(text.size() + (int) add_bos); + const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); assert(n >= 0); res.resize(n); @@ -479,54 +478,339 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - -#if defined (_WIN32) -void win32_console_init(bool enable_color) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; +void console_init(console_state & con_st) { +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); + } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } +#endif + setlocale(LC_ALL, ""); +} + +void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + +#if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif +} + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } } -// Convert a wide Unicode string to an UTF8 string -void win32_utf8_encode(const std::wstring & wstr, std::string & str) { - int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); - std::string strTo(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); - str = strTo; -} +char32_t getchar32() { + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } #endif + + return static_cast(wc); +} + +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + +void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } +} + +// Helper function to remove the last UTF-8 character from a string +void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + } + line.erase(pos); +} + +bool console_readline(console_state & con_st, std::string & line) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(con_st.out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + replace_last(con_st, line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(con_st, ' '); + pop_cursor(con_st); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + replace_last(con_st, line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.multiline_input; + if (is_special_char) { + replace_last(con_st, ' '); + pop_cursor(con_st); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', con_st.out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(con_st); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', con_st.out); + } + } + + fflush(con_st.out); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516f..43f1cc9ef 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,11 @@ #include #include +#if !defined (_WIN32) +#include +#include +#endif + // // CLI argument parsing // @@ -56,6 +61,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +110,20 @@ enum console_color_t { }; struct console_state { + bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5ac151e14..6e1172a48 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.multiline_input = params.multiline_input; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -313,7 +321,8 @@ int main(int argc, char ** argv) { if (n_past + (int) embd.size() > n_ctx) { const int n_left = n_past - params.n_keep; - n_past = params.n_keep; + // always keep the first token - BOS + n_past = std::max(1, params.n_keep); // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); @@ -331,7 +340,6 @@ int main(int argc, char ** argv) { } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) - // REVIEW if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 299a19999..9212dee5c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -25,46 +25,68 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` + // BOS tokens will be added for each chunk before eval auto tokens = ::llama_tokenize(ctx, params.prompt, true); - int count = 0; - int seq_count = tokens.size() / params.n_ctx; - int n_vocab = llama_n_vocab(ctx); + int count = 0; + + const int n_chunk = tokens.size() / params.n_ctx; + const int n_vocab = llama_n_vocab(ctx); + const int n_batch = params.n_batch; double nll = 0.0; - fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch); + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); - for (int i = 0; i < seq_count; ++i) { - int start = i * params.n_ctx; - int end = start + params.n_ctx; + for (int i = 0; i < n_chunk; ++i) { + const int start = i * params.n_ctx; + const int end = start + params.n_ctx; + + const int num_batches = (params.n_ctx + n_batch - 1) / n_batch; std::vector logits; - int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch; - auto start_t = std::chrono::high_resolution_clock::now(); + + const auto t_start = std::chrono::high_resolution_clock::now(); + for (int j = 0; j < num_batches; ++j) { - int batch_start = start + j * params.n_batch; - int batch_size = std::min(end - batch_start, params.n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (j == 0) { + tokens[batch_start] = llama_token_bos(); + } + + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - auto batch_logits = llama_get_logits(ctx); + + // restore the original token in case it was set to BOS + tokens[batch_start] = token_org; + + const auto batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } - auto end_t = std::chrono::high_resolution_clock::now(); + + const auto t_end = std::chrono::high_resolution_clock::now(); + if (i == 0) { - const float seconds = std::chrono::duration(end_t - start_t).count(); - printf("%.2f seconds per pass - ETA ", seconds); - int total_seconds = (int)(seconds * seq_count); + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { - printf("%d hours ", total_seconds / (60*60)); + fprintf(stderr, "%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - printf("%d minutes\n", total_seconds / 60); + fprintf(stderr, "%d minutes\n", total_seconds / 60); } + // We get the logits for all the tokens in the context window (params.n_ctx) // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, - // calculate the perplexity over the last half the window (so the model always has + // calculate the perplexity over the last half of the window (so the model always has // some context to predict the token). // // We rely on the fact that attention in the forward pass only looks at previous @@ -76,10 +98,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // process the entire prompt. for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) { // Calculate probability of next token, given the previous ones. - std::vector tok_logits( - logits.begin() + j * n_vocab, + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, logits.begin() + (j + 1) * n_vocab); - float prob = softmax(tok_logits)[tokens[start + j + 1]]; + + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; + nll += -std::log(prob); ++count; } diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index f0a12bcaf..775e1e44a 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -264,7 +264,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } //determine mem per token - const std::vector tmp = {0, 1, 2, 3}; + const std::vector tmp = {1, 2, 3, 4}; llama_eval(llama_ctx_v1, tmp.data(), tmp.size(), 0, params.n_threads); return ModelLoadResult::SUCCESS; diff --git a/llama.cpp b/llama.cpp index dba9fdaab..caceeed18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -980,8 +980,6 @@ static void llama_model_load_internal( // prepare memory for the weights { - const auto & hparams = model.hparams; - const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; @@ -1062,6 +1060,13 @@ static bool llama_eval_internal( const int n_tokens, const int n_past, const int n_threads) { + + // enforce that the first token is BOS + if (n_past == 0 && tokens[0] != llama_token_bos()) { + fprintf(stderr, "%s: first token must be BOS\n", __func__); + // return false; //never fail. Not even in the face of Armageddon. + } + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1492,7 +1497,7 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co } if (bos) { - output.push_back(1); + output.push_back(llama_token_bos()); } tokenizer.tokenize(text, output); @@ -2738,11 +2743,14 @@ int llama_eval( fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } + // get a more accurate load time, upon first eval + // TODO: fix this if (!ctx->has_evaluated_once) { ctx->t_load_us = ggml_time_us() - ctx->t_start_us; ctx->has_evaluated_once = true; } + return 0; } diff --git a/scripts/ppl-run-all.sh b/scripts/ppl-run-all.sh new file mode 100755 index 000000000..28f31ca71 --- /dev/null +++ b/scripts/ppl-run-all.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# +# quantize +# + +# 7B +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-7b-q4_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-7b-q4_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-7b-q4_2.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-7b-q5_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-7b-q5_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-7b-q8_0.txt + +# 13B +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-13b-q4_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-13b-q4_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-13b-q4_2.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-13b-q5_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-13b-q5_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-13b-q8_0.txt + +# +# perplexity +# + +# 7B +time ./bin/perplexity -m ../models/7B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-f16.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_2.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q8_0.txt + +# 13B +time ./bin/perplexity -m ../models/13B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-f16.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_2.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q8_0.txt