main: add progress spinner on context swap

This commit is contained in:
crasm 2023-09-10 06:24:31 -04:00
parent 21ac3a1503
commit 6eee6b6949

View file

@ -5,6 +5,7 @@
#include "build-info.h"
#include "grammar-parser.h"
#include <atomic>
#include <cassert>
#include <cinttypes>
#include <cmath>
@ -15,6 +16,7 @@
#include <iostream>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@ -40,6 +42,7 @@ static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static std::atomic<bool> is_swapping(false); // for spinner progress indicator
void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
@ -85,6 +88,15 @@ void write_logfile(
fclose(logfile);
}
void display_spinner() {
const char spinner[] = { '/', '-', '\\', '|' };
size_t i = 0;
while (is_swapping) {
std::cerr << spinner[i++ % sizeof(spinner)] << '\b';
std::this_thread::sleep_for(std::chrono::milliseconds(125));
}
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
if (signo == SIGINT) {
@ -523,6 +535,8 @@ int main(int argc, char ** argv) {
break;
}
is_swapping = true; // enable spinner
const int n_left = n_past - params.n_keep;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
@ -541,6 +555,8 @@ int main(int argc, char ** argv) {
path_session.clear();
}
std::thread spinner_thread(display_spinner);
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
@ -626,6 +642,10 @@ int main(int argc, char ** argv) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
// Disable spinner
is_swapping = false;
spinner_thread.join();
}
embd.clear();