diff --git a/examples/common.cpp b/examples/common.cpp
index 348c60500..d9a1e55b1 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -365,8 +365,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, "  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT\n");
-    fprintf(stderr, "                        (can be specified more than once for multiple reverse prompts).\n");
+    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
+    fprintf(stderr, "                        specified more than once for multiple prompts).\n");
     fprintf(stderr, "  --stop KEYWORD        a string that, when output by the model, will stop generation\n");
     fprintf(stderr, "                        (can be specified more than once for multiple keywords).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
diff --git a/examples/common.h b/examples/common.h
index 1ca16dcb6..468b01b32 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -46,10 +46,10 @@ struct gpt_params {
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string path_session = "";       // path to file for saving/loading model eval state
-    std::string input_prefix = "";       // string to prefix user inputs with
-    std::string input_suffix = "";       // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+    std::string path_session = "";          // path to file for saving/loading model eval state
+    std::string input_prefix = "";          // string to prefix user inputs with
+    std::string input_suffix = "";          // string to suffix user inputs with
+    std::vector<std::string> antiprompt;    // string upon seeing which more user input is prompted
     std::vector<std::string> stop_keywords; // string upon seeing which the model will stop
 
     std::string lora_adapter = ""; // lora adapter path
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 713edadb6..9cc6550fd 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
     }
 
     if (params.stop_keywords.size()) {
-        for (auto stop_keyword : params.stop_keywords) {
+        for (auto & stop_keyword : params.stop_keywords) {
            fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
        }
    }
@@ -516,22 +516,17 @@ int main(int argc, char ** argv) {
             console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
         }
 
-        // in interactive mode, and not currently processing queued inputs;
-        // check if we should prompt the user for more
-        if (params.interactive && (int) embd_inp.size() <= n_consumed) {
-
+        // check for stop keywords if we're processing generations
+        if (params.stop_keywords.size() && (int) embd_inp.size() <= n_consumed) {
             std::string last_output;
-            if (params.antiprompt.size() || params.stop_keywords.size()) {
-                for (auto id : last_n_tokens) {
-                    last_output += llama_token_to_str(ctx, id);
-                }
+            for (auto id : last_n_tokens) {
+                last_output += llama_token_to_str(ctx, id);
             }
-
-            // Check for stop keywords, a configurable alternative to the end-of-text token
-            // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
             bool stop = false;
-            for (std::string stop_keyword : params.stop_keywords) {
-                if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
+            for (auto & stop_keyword : params.stop_keywords) {
+                const size_t stop_pos = last_output.find(stop_keyword.c_str(),
+                    last_output.length() - stop_keyword.length(), stop_keyword.length());
+                if (stop_pos != std::string::npos) {
                     stop = true;
                     break;
                 }
@@ -539,9 +534,19 @@ int main(int argc, char ** argv) {
             if (stop) {
                 break;
             }
+        }
+
+        // in interactive mode, and not currently processing queued inputs;
+        // check if we should prompt the user for more
+        if (params.interactive && (int) embd_inp.size() <= n_consumed) {
 
             // check for reverse prompt
             if (params.antiprompt.size()) {
+                std::string last_output;
+                for (auto id : last_n_tokens) {
+                    last_output += llama_token_to_str(ctx, id);
+                }
+
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
                 for (std::string & antiprompt : params.antiprompt) {
@@ -608,24 +613,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // Check for stop keywords, a configurable alternative to the end-of-text token
-        if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
-            std::string last_output;
-            for (auto id : last_n_tokens) {
-                last_output += llama_token_to_str(ctx, id);
-            }
-            bool stop = false;
-            for (std::string stop_keyword : params.stop_keywords) {
-                if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
-                    stop = true;
-                    break;
-                }
-            }
-            if (stop) {
-                break;
-            }
-        }
-
         // end of text token
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             if (params.instruct) {
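
For reference, the stop check above is a pure suffix match: a keyword only triggers when it sits at the very end of the generated text, which the code arranges by anchoring the `find(ptr, pos, n)` overload at `last_output.length() - stop_keyword.length()`. A minimal standalone sketch of that test follows; the helper name `ends_with_any` and the sample keywords are illustrative only, not identifiers from the llama.cpp sources.

```cpp
// Standalone sketch of the suffix match used by the stop-keyword check.
// Not part of the patch; names below are illustrative.
#include <cstdio>
#include <string>
#include <vector>

// Returns true if `text` ends with any of `keywords`.
static bool ends_with_any(const std::string & text, const std::vector<std::string> & keywords) {
    for (const auto & kw : keywords) {
        if (kw.empty() || kw.length() > text.length()) {
            continue; // make the unsigned subtraction below obviously safe
        }
        // find(ptr, pos, n) anchored at the only position where kw can be a
        // suffix; it either matches exactly there or returns npos.
        const size_t pos = text.find(kw.c_str(), text.length() - kw.length(), kw.length());
        if (pos != std::string::npos) {
            return true;
        }
    }
    return false;
}

int main() {
    const std::vector<std::string> keywords = { "###", "</answer>" };
    printf("%d\n", ends_with_any("some output ###", keywords)); // 1: keyword at the end
    printf("%d\n", ends_with_any("### some output", keywords)); // 0: keyword mid-stream
    return 0;
}
```

The explicit length guard is only for readability: even without it, `std::string::find` returns `npos` whenever the start position exceeds the string length, so the wrapped-around unsigned subtraction in the patch cannot match spuriously.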