Initial implementation of stop keywords

This commit is contained in:
Joshua Williams 2023-03-21 12:26:36 -05:00
parent 16ffc013c6
commit 3eed8c0914
3 changed files with 70 additions and 43 deletions

View file

@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) {
}
}
}
if(params.stop_keyword.size()) {
for (auto stop_keyword : params.stop_keyword) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "\n\n");
@ -1129,15 +1136,31 @@ int main(int argc, char ** argv) {
printf(ANSI_COLOR_RESET);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt
// If we are not processing queued inputs, check for reverse prompt and stop keywords
if((int) embd_inp.size() <= input_consumed) {
// Build the output string
// TODO - Recomputing this whole string every iteration is not efficient
std::string last_output;
for (auto id : last_n_tokens) {
last_output += vocab.id_to_token[id];
}
// Check for stop keywords
bool stop = false;
for (std::string stop_keyword : params.stop_keyword) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if(stop) {
break;
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive) {
// Check if each of the reverse prompts appears at the end of the output.
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
@ -1182,6 +1205,7 @@ int main(int argc, char ** argv) {
}
is_interacting = false;
}
}
// end of text token
if (embd.back() == EOS_TOKEN_ID) {

View file

@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--stop") {
params.stop_keyword.push_back(argv[++i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {

View file

@ -32,6 +32,7 @@ struct gpt_params {
std::string prompt = "";
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop
bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided