Introduce -q (quiet flag) and improve ctrl-c ux

This commit is contained in:
Justine Tunney 2023-05-12 09:46:07 -07:00
parent e8de1e4766
commit 45186c74ac
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
5 changed files with 43 additions and 29 deletions

View file

@ -60,6 +60,7 @@ static inline void __oncrash(int sig, struct siginfo *si, void *arg) {
}
static void __got_sigquit(int sig, struct siginfo *si, void *arg) {
write(2, "^\\", 2);
__oncrash(sig, si, arg);
}
static void __got_sigfpe(int sig, struct siginfo *si, void *arg) {

View file

@ -109,6 +109,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.seed = std::stoi(argv[i]);
} else if (arg == "-v" || arg == "--verbose") {
++params.verbose;
} else if (arg == "-q" || arg == "--quiet") {
--params.verbose;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
@ -332,7 +334,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
// if no prompt is specified, then use companion ai
if (params.prompt.empty()) {
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: No prompt specified\n", __func__);
fprintf(stderr, "%s: Loading CompanionAI\n", __func__);
}
@ -368,7 +370,8 @@ void gpt_print_usage(FILE *f, int /*argc*/, char ** argv, const gpt_params & par
fprintf(f, "\n");
fprintf(f, "options:\n");
fprintf(f, " -h, --help show this help message and exit\n");
fprintf(f, " -v, --verbose print plenty of helpful information, e.g. prompt\n");
fprintf(f, " -v, --verbose print helpful information to stderr [repeatable]\n");
fprintf(f, " -s, --silent disables ephemeral progress indicators [repeatable]\n");
fprintf(f, " -i, --interactive run in interactive mode\n");
fprintf(f, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(f, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");

View file

@ -25,7 +25,7 @@ struct gpt_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
int32_t n_ctx = 512; // context size
int32_t n_batch = 64; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_batch = 32; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
// sampling parameters

View file

@ -936,7 +936,7 @@ static void llama_model_load_internal(
hparams.n_ctx = n_ctx;
}
if (verbose) {
if (verbose > 0) {
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
@ -959,7 +959,7 @@ static void llama_model_load_internal(
size_t ctx_size, mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
if (verbose) {
if (verbose > 0) {
fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
}
@ -979,7 +979,7 @@ static void llama_model_load_internal(
const size_t mem_required_state =
scale*MEM_REQ_KV_SELF().at(model.type);
if (verbose) {
if (verbose > 0) {
fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
}
@ -2108,7 +2108,7 @@ struct llama_context * llama_init_from_file(
}
unsigned cur_percentage = 0;
if (verbose && params.progress_callback == NULL) {
if (verbose > 0 && params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;
params.progress_callback = [](float progress, void * ctx) {
unsigned * cur_percentage_p = (unsigned *) ctx;
@ -2146,7 +2146,7 @@ struct llama_context * llama_init_from_file(
return nullptr;
}
if (verbose) {
if (verbose > 0) {
const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}

View file

@ -31,6 +31,7 @@
#include "libc/calls/struct/sigaction.h"
#include "libc/calls/struct/stat.h"
#include "libc/intrin/bits.h"
#include "libc/intrin/kprintf.h"
#include "libc/log/log.h"
#include "libc/macros.internal.h"
#include "libc/nexgen32e/x86feature.h"
@ -62,18 +63,25 @@ static console_state con_st;
////////////////////////////////////////////////////////////////////////////////
static std::atomic<bool> is_interacting;
static std::atomic<bool> is_stalled;
static std::atomic<bool> is_terminated;
static std::atomic<bool> is_interacting;
static void acknowledge_shutdown(void) {
write(2, "^C", 2);
}
static void sigint_handler_batch(int signo) {
is_terminated = true;
acknowledge_shutdown();
}
static void sigint_handler_interactive(int signo) {
if (!is_interacting) {
if (!is_interacting && !is_stalled) {
is_interacting = true;
} else {
is_terminated = true;
acknowledge_shutdown();
}
}
@ -223,7 +231,7 @@ int main(int argc, char ** argv) {
params.seed = time(NULL);
}
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
}
@ -258,7 +266,7 @@ int main(int argc, char ** argv) {
}
// print system information
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
@ -277,7 +285,7 @@ int main(int argc, char ** argv) {
llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
}
if (params.verbose) {
if (params.verbose > 0) {
llama_print_timings(ctx);
}
llama_free(ctx);
@ -365,22 +373,22 @@ int main(int argc, char ** argv) {
sigaction(SIGINT, &sa, NULL);
if (params.interactive) {
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: interactive mode on.\n", __func__);
}
if (params.verbose && params.antiprompt.size()) {
if (params.verbose > 0 && params.antiprompt.size()) {
for (auto antiprompt : params.antiprompt) {
fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
}
}
if (params.verbose && !params.input_prefix.empty()) {
if (params.verbose > 0 && !params.input_prefix.empty()) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n",
@ -388,7 +396,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
}
if (params.verbose && params.interactive) {
if (params.verbose > 0 && params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n"
" - Press Ctrl+C to interject at any time.\n"
" - Press Return to return control to LLaMa.\n"
@ -399,7 +407,7 @@ int main(int argc, char ** argv) {
remember_init();
bool is_antiprompt = false;
bool input_noecho = !params.verbose;
bool input_noecho = params.verbose <= 0;
int n_past = 0;
int n_remain = params.n_predict;
@ -443,7 +451,7 @@ int main(int argc, char ** argv) {
// check expected state size
state_size = llama_get_state_size(ctx);
if (READ64LE(header->state_size) != state_size) {
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: prompt has stale data state size\n",
params.prompt_path.c_str());
}
@ -465,7 +473,7 @@ int main(int argc, char ** argv) {
mtim.tv_sec = READ64LE(header->model_mtim_sec);
mtim.tv_nsec = READ64LE(header->model_mtim_nsec);
if (CompareTime(model_stat.st_mtim, mtim) > 0) {
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: model file timestamp changed; will reload and regenerate prompt\n",
params.prompt_path.c_str());
}
@ -481,7 +489,7 @@ int main(int argc, char ** argv) {
// check prompt textus
if (prompt_size != params.prompt.size() ||
memcmp(header + 1, params.prompt.c_str(), prompt_size) != 0) {
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: prompt text changed; will reload and regenerate\n",
params.prompt_path.c_str());
}
@ -490,7 +498,7 @@ int main(int argc, char ** argv) {
// read the transformer state
llama_set_state_data(ctx, (uint8_t *)(header + 1) + prompt_size);
// we're finished loading the prompt file
if (params.verbose) {
if (params.verbose > 0) {
fprintf(stderr, "%s: %s: reloaded previously saved prompt\n",
__func__, params.prompt_path.c_str());
}
@ -508,7 +516,7 @@ int main(int argc, char ** argv) {
close(fd);
}
if (prompt_status == kPromptPending && params.verbose) {
if (prompt_status == kPromptPending && params.verbose > 0) {
// the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
}
@ -535,11 +543,13 @@ int main(int argc, char ** argv) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
is_stalled = n_eval > 1;
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
return 1;
}
is_stalled = false;
n_past += n_eval;
if (prompt_status == kPromptPending &&
!params.verbose && con_st.use_color && embd_inp.size()) {
@ -599,11 +609,11 @@ int main(int argc, char ** argv) {
llama_copy_state_data(ctx, (uint8_t *)map + sizeof(header) + params.prompt.size());
memcpy((uint8_t *)map + sizeof(header), params.prompt.c_str(), params.prompt.size());
memcpy(map, &header, sizeof(header));
if (msync(map, file_size, MS_ASYNC) && params.verbose) {
if (msync(map, file_size, MS_ASYNC) && params.verbose > 0) {
fprintf(stderr, "%s: msync failed: %s\n",
tmppath.c_str(), strerror(errno));
}
if (munmap(map, file_size) && params.verbose) {
if (munmap(map, file_size) && params.verbose > 0) {
fprintf(stderr, "%s: munmap failed: %s\n",
tmppath.c_str(), strerror(errno));
}
@ -877,7 +887,7 @@ int main(int argc, char ** argv) {
if (!embd.empty() && embd.back() == llama_token_eos()) {
if (params.instruct) {
is_interacting = true;
} else if (params.verbose) {
} else if (params.verbose > 0) {
fprintf(stderr, " [end of text]\n");
break;
}
@ -893,13 +903,13 @@ int main(int argc, char ** argv) {
if (is_terminated) {
console_cleanup(con_st);
printf("\n");
if (params.verbose) {
if (params.verbose > 0) {
llama_print_timings(ctx);
}
_exit(128 + SIGINT);
}
if (params.verbose) {
if (params.verbose > 0) {
llama_print_timings(ctx);
}
llama_free(ctx);