feat: add "stop" keywords as alternative to eot token
commit 72f102a4ae
parent e6a46b0ed1
3 changed files with 54 additions and 5 deletions
@@ -283,6 +283,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.antiprompt.push_back(argv[i]);
+        } else if (arg == "--stop") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.stop_keywords.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
@@ -359,8 +365,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, "  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
-    fprintf(stderr, "                        specified more than once for multiple prompts).\n");
+    fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT");
+    fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n");
+    fprintf(stderr, "  --stop KEYWORD        a string that, when output by the model, will stop generation\n");
+    fprintf(stderr, "                        (can be specified more than once for multiple keywords).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
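With this change the new flag can be passed on the command line, more than once if several keywords are wanted. As an illustration (binary name, model path, prompt, and keyword values here are only examples), something like `./main -m models/7B/ggml-model.bin -p "Q: What is the capital of France? A:" --stop "Q:"` would end generation as soon as the model prints "Q:", in addition to the usual end-of-text token.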
@@ -50,6 +50,7 @@ struct gpt_params {
     std::string input_prefix = ""; // string to prefix user inputs with
     std::string input_suffix = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt;    // string upon seeing which more user input is prompted
+    std::vector<std::string> stop_keywords; // string upon seeing which the model will stop

     std::string lora_adapter = ""; // lora adapter path
     std::string lora_base = "";    // base model path for the lora adapter
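For code that fills in a `gpt_params` directly rather than going through `gpt_params_parse`, the new field is used like the existing `antiprompt` vector. A minimal sketch (the keyword strings are invented for illustration):

    gpt_params params;
    params.antiprompt.push_back("User:");       // existing reverse-prompt mechanism
    params.stop_keywords.push_back("### End");  // new: generation stops once this string is produced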
@@ -264,6 +264,13 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
+
+    if (params.stop_keywords.size()) {
+        for (auto stop_keyword : params.stop_keywords) {
+            fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
+        }
+    }
+
     fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
         params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
@@ -513,13 +520,28 @@ int main(int argc, char ** argv) {
         // check if we should prompt the user for more
         if (params.interactive && (int) embd_inp.size() <= n_consumed) {

-            // check for reverse prompt
-            if (params.antiprompt.size()) {
-                std::string last_output;
+            std::string last_output;
+            if (params.antiprompt.size() || params.stop_keywords.size()) {
                 for (auto id : last_n_tokens) {
                     last_output += llama_token_to_str(ctx, id);
                 }
+            }

+            // Check for stop keywords, a configurable alternative to the end-of-text token
+            // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
+            bool stop = false;
+            for (std::string stop_keyword : params.stop_keywords) {
+                if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
+                    stop = true;
+                    break;
+                }
+            }
+            if (stop) {
+                break;
+            }
+
+            // check for reverse prompt
+            if (params.antiprompt.size()) {
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
                 for (std::string & antiprompt : params.antiprompt) {
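The keyword test uses the three-argument `std::string::find(const char *, pos, count)` overload with the search anchored at `last_output.length() - stop_keyword.length()`, so it effectively asks whether the accumulated output ends with the keyword. A self-contained sketch of the same idea (the helper name and explicit length guard are illustrative, not part of the commit):

    #include <string>
    #include <vector>

    // Illustrative helper: true if `output` ends with any of the configured stop keywords.
    static bool ends_with_stop_keyword(const std::string & output,
                                       const std::vector<std::string> & stop_keywords) {
        for (const std::string & keyword : stop_keywords) {
            if (keyword.size() <= output.size() &&
                output.compare(output.size() - keyword.size(), keyword.size(), keyword) == 0) {
                return true;
            }
        }
        return false;
    }

The explicit size guard covers keywords longer than the output; the committed code reaches the same "no match" outcome because the unsigned start position wraps past the end of the string and find then returns std::string::npos.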
@@ -586,6 +608,24 @@ int main(int argc, char ** argv) {
             }
         }

+        // Check for stop keywords, a configurable alternative to the end-of-text token
+        if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
+            std::string last_output;
+            for (auto id : last_n_tokens) {
+                last_output += llama_token_to_str(ctx, id);
+            }
+            bool stop = false;
+            for (std::string stop_keyword : params.stop_keywords) {
+                if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
+                    stop = true;
+                    break;
+                }
+            }
+            if (stop) {
+                break;
+            }
+        }
+
         // end of text token
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             if (params.instruct) {