From 660a4d5365f1b89ee82827bfe6e4201996506d54 Mon Sep 17 00:00:00 2001
From: Thomas Antony
Date: Fri, 17 Mar 2023 19:03:20 -0700
Subject: [PATCH] Refactor interactive mode in main.cpp

---
 main.cpp | 155 ++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 89 insertions(+), 66 deletions(-)

diff --git a/main.cpp b/main.cpp
index 0165ce348..f9d6f894a 100644
--- a/main.cpp
+++ b/main.cpp
@@ -28,7 +28,6 @@
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_BOLD "\x1b[1m"
 
-static const int EOS_TOKEN_ID = 2;
 
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
@@ -56,6 +55,8 @@ void sigint_handler(int signo) {
 }
 #endif
 
+void process_interactive_input(llama_context& ctx, const gpt_params& params);
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -86,15 +87,18 @@
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-    // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+    llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
@@ -110,8 +114,9 @@
     }
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
 
+    // Setup interactive mode
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -150,43 +155,43 @@
         is_interacting = true;
     }
 
-    bool input_noecho = false;
-
-    int remaining_tokens = params.n_predict;
-
     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if(!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
     }
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
-    }
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (llama_context_is_finished(ctx) == false) {
+        std::string model_output{};
 
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default if there is no pending user input
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        } else {
+            // Run inference if we don't have any pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
         // reset color to default if there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -195,10 +200,9 @@
 
         // in interactive mode, and not currently processing queued inputs;
        // check if we should prompt the user for more
-        if (params.interactive) {
-            // check for reverse prompt
-            for (auto antiprompt_inp : antipromptv_inp) {
-                if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            // check for reverse prompt
+            if (antiprompt_inp.size() && llama_is_anti_prompt_present(ctx, antiprompt_inp)) {
                 // reverse prompt found
                 is_interacting = true;
                 break;
@@ -206,37 +210,21 @@
             }
             if (is_interacting) {
                 if (params.instruct) {
-                    input_consumed = embd_inp.size();
-                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+                    llama_update_input(ctx, "\n\n### Instruction:\n\n");
 
                     printf("\n> ");
                 }
 
                 // currently being interactive
-                if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                std::string buffer;
-                std::string line;
-                bool another_line = true;
-                do {
-                    std::getline(std::cin, line);
-                    if (line.empty() || line.back() != '\\') {
-                        another_line = false;
-                    } else {
-                        line.pop_back(); // Remove the continue character
-                    }
-                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
-                }
-
-                remaining_tokens -= line_inp.size();
-
-                input_noecho = true; // do not echo this again
+                process_interactive_input(ctx, params);
+                input_noecho = true; // do not echo this input again
+                is_interacting = false;
             }
-            is_interacting = false;
         }
 
         // end of text token
-        if (embd.back() == EOS_TOKEN_ID) {
+        if (is_end_of_text) {
            if (params.interactive) {
                 is_interacting = true;
             } else {
@@ -246,23 +234,58 @@
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
-        if (params.interactive && remaining_tokens <= 0) {
-            remaining_tokens = params.n_predict;
+        if (params.interactive && llama_context_is_finished(ctx)) {
+            llama_reset_remaining_tokens(ctx);
             is_interacting = true;
         }
     }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
         llama_print_end_stats(ctx);
         fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
     if (params.use_color) {
         printf(ANSI_COLOR_RESET);
     }
-
     return 0;
 }
+
+void process_interactive_input(llama_context& ctx, const gpt_params& params)
+{
+    bool another_line = true;
+    while (another_line) {
+        fflush(stdout);
+        // 255 input chars plus room for a trailing '\n' and the terminating NUL
+        char buf[257] = {0};
+        int n_read;
+        if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+            // presumably an empty line, consume the newline
+            std::ignore = scanf("%*c");
+            n_read = 0;
+        }
+        if (params.use_color) printf(ANSI_COLOR_RESET);
+
+        if (n_read > 0 && buf[n_read-1] == '\\') {
+            another_line = true;
+            buf[n_read-1] = '\n';
+            buf[n_read] = 0;
+        } else {
+            another_line = false;
+            buf[n_read] = '\n';
+            buf[n_read+1] = 0;
+        }
+
+        // Do not clear existing context in interactive mode
+        llama_update_input(ctx, buf);
+    }
+}
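
This refactor leaves main() as a thin driver over the new llama_context calls. As a
condensed reference for reviewers, the sketch below shows the same flow as a minimal,
non-interactive driver, using only functions that appear in the hunks above
(llama_init_from_params, llama_add_bos, llama_update_input, llama_prepare_context,
llama_has_unconsumed_input, llama_ingest_all_pending_input, llama_infer,
llama_context_is_finished, llama_free_context). The run_prompt() helper and the
"llama.h" include are illustrative assumptions, not part of this patch.

    #include <cstdio>
    #include <string>
    #include "llama.h"  // assumed: header exposing llama_context, gpt_params and the API above

    // Sketch: feed one prompt through the refactored API and stream the output.
    static int run_prompt(gpt_params params) {
        llama_context* ctx_ptr = llama_init_from_params(params);
        if (!ctx_ptr) {
            fprintf(stderr, "failed to load model from '%s'\n", params.model.c_str());
            return 1;
        }
        llama_context& ctx = *ctx_ptr;

        llama_add_bos(ctx);                      // send "beginning of string"
        llama_update_input(ctx, params.prompt);  // queue the prompt as pending input
        if (!llama_prepare_context(ctx)) {
            fprintf(stderr, "failed to prepare context\n");
            llama_free_context(ctx_ptr);
            return 1;
        }

        bool is_end_of_text = false;
        while (!llama_context_is_finished(ctx)) {
            if (llama_has_unconsumed_input(ctx)) {
                // consume queued input first; the flag echoes ingested tokens,
                // mirroring the !input_noecho argument in main()
                llama_ingest_all_pending_input(ctx, true);
            } else {
                std::string token_text{};
                llama_infer(ctx, token_text, is_end_of_text);  // one token per call
                printf("%s", token_text.c_str());
                fflush(stdout);
                if (is_end_of_text) break;
            }
        }

        llama_free_context(ctx_ptr);
        return 0;
    }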