Make shell usability improvements to llama.cpp

- Introduce -v and --verbose flags
- Don't print stats / diagnostics unless -v is passed
- Reduce --top_p default from 0.95 to 0.70
- Change --reverse-prompt to no longer imply --interactive
- Permit --reverse-prompt specifying custom EOS if non-interactive
commit 1c2da3a55a (parent 420f889ac3)
Author: Justine Tunney
Date:   2023-04-28 02:54:11 -07:00
6 changed files with 103 additions and 55 deletions


@@ -16,6 +16,11 @@ ORIGIN
 LOCAL CHANGES
+- Introduce -v and --verbose flags
+- Don't print stats / diagnostics unless -v is passed
+- Reduce --top_p default from 0.95 to 0.70
+- Change --reverse-prompt to no longer imply --interactive
+- Permit --reverse-prompt specifying custom EOS if non-interactive
 - Refactor headers per cosmo convention
 - Replace code like 'ggjt' with READ32BE("ggjt")
 - Remove C++ exceptions; use Die() function instead


@@ -91,6 +91,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.seed = std::stoi(argv[i]);
+        } else if (arg == "-v" || arg == "--verbose") {
+            ++params.verbose;
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;

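A note on the parsing above: -v/--verbose is implemented as a counter (++params.verbose) rather than a boolean, so repeated -v flags could in principle select higher verbosity levels, although this commit only ever tests whether the count is nonzero. A minimal standalone sketch of that counting-flag pattern (illustrative only, not code from this repository):

// Counting-flag sketch: "-v -v" would yield verbose == 2 if callers chose to
// distinguish levels; the commit above only checks verbose != 0.
#include <cstdio>
#include <cstring>

int main(int argc, char ** argv) {
    int verbose = 0;
    for (int i = 1; i < argc; i++) {
        if (!strcmp(argv[i], "-v") || !strcmp(argv[i], "--verbose")) {
            ++verbose;
        }
    }
    if (verbose) {
        fprintf(stderr, "verbosity level: %d\n", verbose);
    }
    return 0;
}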

@@ -17,6 +17,7 @@
 struct gpt_params {
     int32_t seed = -1; // RNG seed
+    int32_t verbose = 0; // Logging verbosity
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 128; // new tokens to predict
     int32_t repeat_last_n = 64; // last n tokens to penalize
@@ -27,7 +28,7 @@ struct gpt_params {
     // sampling parameters
     int32_t top_k = 40;
-    float top_p = 0.95f;
+    float top_p = 0.70f;
     float temp = 0.80f;
     float repeat_penalty = 1.10f;

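The struct change above lowers the default nucleus-sampling threshold: with top_p = 0.70 the sampler keeps a smaller slice of the probability mass than with 0.95, so default output is more conservative. A hedged sketch of the truncation rule (not the actual llama.cpp sampler; assumes token probabilities are already sorted in descending order):

#include <cstddef>
#include <vector>

// Keep the smallest prefix of tokens whose cumulative probability reaches top_p.
// A lower top_p (0.70 vs 0.95) keeps fewer candidate tokens.
size_t top_p_cutoff(const std::vector<float> & sorted_probs, float top_p) {
    float cum = 0.0f;
    for (size_t i = 0; i < sorted_probs.size(); i++) {
        cum += sorted_probs[i];
        if (cum >= top_p) {
            return i + 1; // number of tokens to keep
        }
    }
    return sorted_probs.size();
}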

@@ -451,7 +451,7 @@ struct llama_file_loader {
     llama_file_loader(const char * fname, size_t file_idx, llama_load_tensors_map & tensors_map)
         : file(fname, "rb") {
-        fprintf(stderr, "llama.cpp: loading model from %s\n", fname);
+        // fprintf(stderr, "llama.cpp: loading model from %s\n", fname);
         read_magic();
         read_hparams();
         read_vocab();
@@ -561,7 +561,7 @@ struct llama_file_saver {
     llama_file_loader * any_file_loader;
     llama_file_saver(const char * fname, llama_file_loader * any_file_loader, enum llama_ftype new_ftype)
         : file(fname, "wb"), any_file_loader(any_file_loader) {
-        fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
+        // fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
         write_magic();
         write_hparams(new_ftype);
         write_vocab();
@@ -919,7 +919,8 @@ static void llama_model_load_internal(
         bool use_mlock,
         bool vocab_only,
         llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
+        void * progress_callback_user_data,
+        int verbose) {

     lctx.t_start_us = ggml_time_us();
@@ -943,7 +944,7 @@ static void llama_model_load_internal(
         hparams.n_ctx = n_ctx;
     }

-    {
+    if (verbose) {
         fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
         fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
@@ -966,7 +967,9 @@ static void llama_model_load_internal(
     size_t ctx_size, mmapped_size;
     ml->calc_sizes(&ctx_size, &mmapped_size);
-    fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
+    if (verbose) {
+        fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
+    }

     // print memory requirements
     {
@@ -984,8 +987,10 @@ static void llama_model_load_internal(
         const size_t mem_required_state =
             scale*MEM_REQ_KV_SELF().at(model.type);
-        fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+        if (verbose) {
+            fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                    mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+        }
     }

     // create the ggml context
@@ -1068,10 +1073,12 @@ static bool llama_model_load(
         bool use_mlock,
         bool vocab_only,
         llama_progress_callback progress_callback,
-        void *progress_callback_user_data) {
+        void *progress_callback_user_data,
+        int verbose) {
     // try {
         llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
-                                  vocab_only, progress_callback, progress_callback_user_data);
+                                  vocab_only, progress_callback, progress_callback_user_data,
+                                  verbose);
         return true;
     // } catch (const std::string & err) {
     //     fprintf(stderr, "error loading model: %s\n", err.c_str());
@@ -1783,7 +1790,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 struct llama_context * llama_init_from_file(
         const char * path_model,
-        struct llama_context_params params) {
+        struct llama_context_params params,
+        int verbose) {
     ggml_time_init();

     llama_context * ctx = new llama_context;
@@ -1793,7 +1801,7 @@ struct llama_context * llama_init_from_file(
     }

     unsigned cur_percentage = 0;
-    if (params.progress_callback == NULL) {
+    if (verbose && params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
         params.progress_callback = [](float progress, void * ctx) {
             unsigned * cur_percentage_p = (unsigned *) ctx;
@@ -1816,7 +1824,8 @@ struct llama_context * llama_init_from_file(
     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
                           params.use_mmap, params.use_mlock, params.vocab_only,
-                          params.progress_callback, params.progress_callback_user_data)) {
+                          params.progress_callback, params.progress_callback_user_data,
+                          verbose)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
@@ -1830,7 +1839,7 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }

-    {
+    if (verbose) {
         const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
         fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }


@@ -87,7 +87,8 @@ extern "C" {
     // Return NULL on failure
     LLAMA_API struct llama_context * llama_init_from_file(
                              const char * path_model,
-                             struct llama_context_params params);
+                             struct llama_context_params params,
+                             int verbose);

     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);

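Since llama_init_from_file() now takes a trailing verbosity argument, every caller has to pass it explicitly; main.cpp below forwards the parsed --verbose count. A hedged caller-side snippet (error handling trimmed; the model path and context size are placeholders, not values from this commit):

// Hypothetical caller: forwarding the -v count keeps loader diagnostics quiet by default.
llama_context_params lparams = llama_context_default_params();
lparams.n_ctx = 512;
llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams, params.verbose);
if (ctx == NULL) {
    fprintf(stderr, "failed to load model\n");
    return 1;
}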

@@ -96,8 +96,7 @@ void sigint_handler(int signo) {
         if (!is_interacting) {
             is_interacting=true;
         } else {
-            llama_print_timings(*g_ctx);
-            _exit(130);
+            _exit(128 + signo);
         }
     }
 }
@@ -155,7 +154,9 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }

-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    if (params.verbose) {
+        fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    }

     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -179,7 +180,7 @@ int main(int argc, char ** argv) {
         lparams.use_mmap = params.use_mmap;
         lparams.use_mlock = params.use_mlock;

-        ctx = llama_init_from_file(params.model.c_str(), lparams);
+        ctx = llama_init_from_file(params.model.c_str(), lparams, params.verbose);

         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
@@ -199,7 +200,7 @@ int main(int argc, char ** argv) {
     }

     // print system information
-    {
+    if (params.verbose) {
         fprintf(stderr, "\n");
         fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
@@ -218,7 +219,9 @@ int main(int argc, char ** argv) {
             llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
         }

-        llama_print_timings(ctx);
+        if (params.verbose) {
+            llama_print_timings(ctx);
+        }
         llama_free(ctx);

         return 0;
@@ -252,8 +255,8 @@ int main(int argc, char ** argv) {
         params.antiprompt.push_back("### Instruction:\n\n");
     }

-    // enable interactive mode if reverse prompt or interactive start is specified
-    if (params.antiprompt.size() != 0 || params.interactive_first) {
+    // enable interactive mode if interactive start is specified
+    if (params.interactive_first) {
         params.interactive = true;
     }
@@ -288,28 +291,33 @@ int main(int argc, char ** argv) {
         signal(SIGINT, sigint_handler);
 #endif

-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        if (params.verbose) {
+            fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        }

-        if (params.antiprompt.size()) {
+        if (params.verbose && params.antiprompt.size()) {
             for (auto antiprompt : params.antiprompt) {
                 fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
             }
         }

-        if (!params.input_prefix.empty()) {
+        if (params.verbose && !params.input_prefix.empty()) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }

-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
-    fprintf(stderr, "\n\n");
+    if (params.verbose) {
+        fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+            params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+        fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+        fprintf(stderr, "\n\n");
+    }

     // TODO: replace with ring-buffer
     std::vector<llama_token> last_n_tokens(n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

-    if (params.interactive) {
+    if (params.verbose && params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
@@ -320,7 +328,7 @@ int main(int argc, char ** argv) {
     }

     bool is_antiprompt = false;
-    bool input_noecho = false;
+    bool input_noecho = !params.verbose;

     int n_past = 0;
     int n_remain = params.n_predict;
@@ -427,6 +435,40 @@ int main(int argc, char ** argv) {
             }
         }

+        // checks for reverse prompt
+        //
+        // 1. in interactive mode, this lets us detect when the llm is
+        //    prompting the user, so we can pause for input, e.g.
+        //
+        //      --interactive
+        //      --prompt $'CompanionAI: How can I help you?\nHuman:'
+        //      --reverse-prompt 'Human:'
+        //
+        // 2. in normal mode, the reverse prompt can be used to specify
+        //    a custom EOS token, e.g.
+        //
+        //      --prompt 'Question: How old are you?\nAnswer: '
+        //      --reverse-prompt $'\n'
+        //
+        if (params.antiprompt.size()) {
+            std::string last_output;
+            for (auto id : last_n_tokens) {
+                last_output += llama_token_to_str(ctx, id);
+            }
+            is_antiprompt = false;
+            // Check if each of the reverse prompts appears at the end of the output.
+            for (std::string & antiprompt : params.antiprompt) {
+                if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                    is_antiprompt = true;
+                    break;
+                }
+            }
+            if (is_antiprompt && !params.interactive) {
+                printf("\n");
+                break;
+            }
+        }
+
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
@@ -435,34 +477,20 @@ int main(int argc, char ** argv) {
             fflush(stdout);
         }

         // reset color to default if we there is no pending user input
-        if (!input_noecho && (int)embd_inp.size() == n_consumed) {
+        if (params.verbose && !input_noecho && (int)embd_inp.size() == n_consumed) {
             set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
         }

+        if (is_antiprompt) {
+            is_interacting = true;
+            set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+            fflush(stdout);
+        }
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
         if (params.interactive && (int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
-            if (params.antiprompt.size()) {
-                std::string last_output;
-                for (auto id : last_n_tokens) {
-                    last_output += llama_token_to_str(ctx, id);
-                }
-                is_antiprompt = false;
-                // Check if each of the reverse prompts appears at the end of the output.
-                for (std::string & antiprompt : params.antiprompt) {
-                    if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
-                        is_interacting = true;
-                        is_antiprompt = true;
-                        set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
-                        fflush(stdout);
-                        break;
-                    }
-                }
-            }

             if (n_past > 0 && is_interacting) {
                 // potentially set color to indicate we are taking user input
                 set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
@@ -542,7 +570,7 @@ int main(int argc, char ** argv) {
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             if (params.instruct) {
                 is_interacting = true;
-            } else {
+            } else if (params.verbose) {
                 fprintf(stderr, " [end of text]\n");
                 break;
             }
@@ -559,7 +587,9 @@ int main(int argc, char ** argv) {
     signal(SIGINT, SIG_DFL);
 #endif

-    llama_print_timings(ctx);
+    if (params.verbose) {
+        llama_print_timings(ctx);
+    }
     llama_free(ctx);

     set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
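
Taken together, the main.cpp changes make --reverse-prompt usable without --interactive: the new antiprompt block is essentially a suffix check on the rolling output, and when it matches in non-interactive mode the program prints a newline and stops, so the reverse prompt behaves like a custom EOS marker. A simplified, self-contained sketch of that check (the helper name is illustrative, not from this commit):

#include <string>
#include <vector>

// Returns true when the accumulated output ends with any reverse prompt.
static bool ends_with_antiprompt(const std::string & output,
                                 const std::vector<std::string> & antiprompts) {
    for (const std::string & a : antiprompts) {
        if (output.size() >= a.size() &&
            output.compare(output.size() - a.size(), a.size(), a) == 0) {
            return true;
        }
    }
    return false;
}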