Initial implementation of stop keywords
This commit is contained in:
parent
16ffc013c6
commit
3eed8c0914
3 changed files with 70 additions and 43 deletions
32
main.cpp
32
main.cpp
|
@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(params.stop_keyword.size()) {
|
||||
for (auto stop_keyword : params.stop_keyword) {
|
||||
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
|
@ -1129,15 +1136,31 @@ int main(int argc, char ** argv) {
|
|||
printf(ANSI_COLOR_RESET);
|
||||
}
|
||||
|
||||
// in interactive mode, and not currently processing queued inputs;
|
||||
// check if we should prompt the user for more
|
||||
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
|
||||
// check for reverse prompt
|
||||
// If we are not processing queued inputs, check for reverse prompt and stop keywords
|
||||
if((int) embd_inp.size() <= input_consumed) {
|
||||
// Build the output string
|
||||
// TODO: rebuilding this whole string from last_n_tokens on every iteration is inefficient; update it incrementally instead
|
||||
std::string last_output;
|
||||
for (auto id : last_n_tokens) {
|
||||
last_output += vocab.id_to_token[id];
|
||||
}
|
||||
|
||||
// Check for stop keywords
|
||||
bool stop = false;
|
||||
for (std::string stop_keyword : params.stop_keyword) {
|
||||
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
|
||||
stop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(stop) {
|
||||
break;
|
||||
}
|
||||
|
||||
// in interactive mode, and not currently processing queued inputs;
|
||||
// check if we should prompt the user for more
|
||||
if (params.interactive) {
|
||||
|
||||
// Check if each of the reverse prompts appears at the end of the output.
|
||||
for (std::string antiprompt : params.antiprompt) {
|
||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||
|
@ -1182,6 +1205,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
is_interacting = false;
|
||||
}
|
||||
}
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == EOS_TOKEN_ID) {
|
||||
|
|
|
@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.use_color = true;
|
||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||
params.antiprompt.push_back(argv[++i]);
|
||||
} else if (arg == "--stop") {
|
||||
params.stop_keyword.push_back(argv[++i]);
|
||||
} else if (arg == "--perplexity") {
|
||||
params.perplexity = true;
|
||||
} else if (arg == "--ignore-eos") {
|
||||
|
|
1
utils.h
1
utils.h
|
@ -32,6 +32,7 @@ struct gpt_params {
|
|||
std::string prompt = "";
|
||||
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop
|
||||
|
||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue