From 72f102a4aeb60ae7364d8ec1816e566ebbac0cc1 Mon Sep 17 00:00:00 2001 From: Claude Doppler Date: Tue, 4 Apr 2023 20:33:09 +0000 Subject: [PATCH] feat: add "stop" keywords as alternative to eot token --- examples/common.cpp | 12 +++++++++-- examples/common.h | 1 + examples/main/main.cpp | 46 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 7aa77587b..edb27b4c1 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -283,6 +283,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.antiprompt.push_back(argv[i]); + } else if (arg == "--stop") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.stop_keywords.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { @@ -359,8 +365,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); - fprintf(stderr, " specified more than once for multiple prompts).\n"); + fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT"); + fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n"); + fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); + fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); diff --git a/examples/common.h b/examples/common.h index 43f1cc9ef..1ca16dcb6 100644 --- a/examples/common.h +++ b/examples/common.h @@ -50,6 +50,7 @@ struct gpt_params { std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted + std::vector stop_keywords; // string upon seeing which the model will stop std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e1172a48..713edadb6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,6 +264,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str()); } } + + if (params.stop_keywords.size()) { + for (auto stop_keyword : params.stop_keywords) { + fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str()); + } + } + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); @@ -513,13 +520,28 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && (int) embd_inp.size() <= n_consumed) { - // check for reverse prompt - if (params.antiprompt.size()) { - std::string last_output; + std::string last_output; + if (params.antiprompt.size() || params.stop_keywords.size()) { for (auto id : last_n_tokens) { last_output += llama_token_to_str(ctx, id); } + } + // Check for stop keywords, a configurable alternative to the end-of-text token + // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + + // check for reverse prompt + if (params.antiprompt.size()) { is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. for (std::string & antiprompt : params.antiprompt) { @@ -586,6 +608,24 @@ int main(int argc, char ** argv) { } } + // Check for stop keywords, a configurable alternative to the end-of-text token + if (!params.interactive && params.stop_keywords.size() && !is_interacting) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + } + // end of text token if (!embd.empty() && embd.back() == llama_token_eos()) { if (params.instruct) {