From 6eee6b6949de9897edbe91ef6f59f7c8df4a4074 Mon Sep 17 00:00:00 2001 From: crasm Date: Sun, 10 Sep 2023 06:24:31 -0400 Subject: [PATCH] main: add progress spinner on context swap --- examples/main/main.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index baec6ba12..0b83b7736 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,6 +5,7 @@ #include "build-info.h" #include "grammar-parser.h" +#include #include #include #include @@ -15,6 +16,7 @@ #include #include #include +#include #include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) @@ -40,6 +42,7 @@ static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; static bool is_interacting = false; +static std::atomic is_swapping(false); // for spinner progress indicator void write_logfile( const llama_context * ctx, const gpt_params & params, const llama_model * model, @@ -85,6 +88,15 @@ void write_logfile( fclose(logfile); } +void display_spinner() { + const char spinner[] = { '/', '-', '\\', '|' }; + size_t i = 0; + while (is_swapping) { + std::cerr << spinner[i++ % sizeof(spinner)] << '\b'; + std::this_thread::sleep_for(std::chrono::milliseconds(125)); + } +} + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { if (signo == SIGINT) { @@ -523,6 +535,8 @@ int main(int argc, char ** argv) { break; } + is_swapping = true; // enable spinner + const int n_left = n_past - params.n_keep; LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep); @@ -541,6 +555,8 @@ int main(int argc, char ** argv) { path_session.clear(); } + std::thread spinner_thread(display_spinner); + // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; @@ -626,6 +642,10 @@ int main(int argc, char ** argv) { session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); n_session_consumed = session_tokens.size(); } + + // Disable spinner + is_swapping = false; + spinner_thread.join(); } embd.clear();