main: add progress spinner on context swap

This commit is contained in:
crasm 2023-09-10 06:24:31 -04:00
parent 21ac3a1503
commit 6eee6b6949

View file

@ -5,6 +5,7 @@
#include "build-info.h" #include "build-info.h"
#include "grammar-parser.h" #include "grammar-parser.h"
#include <atomic>
#include <cassert> #include <cassert>
#include <cinttypes> #include <cinttypes>
#include <cmath> #include <cmath>
@ -15,6 +16,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <thread>
#include <vector> #include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@ -40,6 +42,7 @@ static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss; static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens; static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false; static bool is_interacting = false;
static std::atomic<bool> is_swapping(false); // for spinner progress indicator
void write_logfile( void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model, const llama_context * ctx, const gpt_params & params, const llama_model * model,
@ -85,6 +88,15 @@ void write_logfile(
fclose(logfile); fclose(logfile);
} }
// Draws a rotating "/-\|" progress glyph on stderr for as long as the global
// atomic flag `is_swapping` is set. Meant to be run on a helper thread while
// the main thread does the work; the caller clears `is_swapping` and joins.
//
// Each glyph is followed by a backspace so the next one overwrites it in place.
// The animation advances every 125 ms, but the flag is polled in shorter
// slices so the joining thread is not held up for a whole frame once the work
// is finished. If at least one glyph was drawn it is blanked out before
// returning, so no stale spinner character is left behind; if the flag was
// already clear on entry, nothing is written at all.
void display_spinner() {
    static const char frames[] = { '/', '-', '\\', '|' };
    const size_t n_frames = sizeof(frames) / sizeof(frames[0]);

    const auto frame_time = std::chrono::milliseconds(125); // animation speed
    const auto poll_time  = std::chrono::milliseconds(25);  // stop-flag latency

    bool drawn = false;
    for (size_t i = 0; is_swapping; ++i) {
        std::cerr << frames[i % n_frames] << '\b';
        drawn = true;

        // Wait one frame, but bail out early as soon as the flag is cleared.
        for (auto waited = std::chrono::milliseconds(0);
             waited < frame_time && is_swapping;
             waited += poll_time) {
            std::this_thread::sleep_for(poll_time);
        }
    }

    if (drawn) {
        // Overwrite the last glyph with a space and step back onto it.
        std::cerr << " \b";
    }
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) { void sigint_handler(int signo) {
if (signo == SIGINT) { if (signo == SIGINT) {
@ -523,6 +535,8 @@ int main(int argc, char ** argv) {
break; break;
} }
is_swapping = true; // enable spinner
const int n_left = n_past - params.n_keep; const int n_left = n_past - params.n_keep;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep); LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
@ -541,6 +555,8 @@ int main(int argc, char ** argv) {
path_session.clear(); path_session.clear();
} }
std::thread spinner_thread(display_spinner);
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
if (n_session_consumed < (int) session_tokens.size()) { if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0; size_t i = 0;
@ -626,6 +642,10 @@ int main(int argc, char ** argv) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size(); n_session_consumed = session_tokens.size();
} }
// Disable spinner
is_swapping = false;
spinner_thread.join();
} }
embd.clear(); embd.clear();