Initial implementation of stop keywords

This commit is contained in:
Joshua Williams 2023-03-21 12:26:36 -05:00
parent 16ffc013c6
commit 3eed8c0914
3 changed files with 70 additions and 43 deletions

110
main.cpp
View file

@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) {
}
}
}
if(params.stop_keyword.size()) {
for (auto stop_keyword : params.stop_keyword) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "\n\n");
@ -1129,58 +1136,75 @@ int main(int argc, char ** argv) {
printf(ANSI_COLOR_RESET);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt
// If we are not processing queued inputs, check for reverse prompt and stop keywords
if((int) embd_inp.size() <= input_consumed) {
// Build the output string
// TODO - Recomputing this whole string every iteration is not efficient
std::string last_output;
for (auto id : last_n_tokens) {
last_output += vocab.id_to_token[id];
}
// Check if each of the reverse prompts appears at the end of the output.
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
// Check for stop keywords
bool stop = false;
for (std::string stop_keyword : params.stop_keyword) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (is_interacting) {
if (params.instruct) {
input_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
printf("\n> ");
}
// currently being interactive
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
std::string buffer;
std::string line;
bool another_line = true;
do {
std::getline(std::cin, line);
if (line.empty() || line.back() != '\\') {
another_line = false;
} else {
line.pop_back(); // Remove the continue character
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);
if (params.use_color) printf(ANSI_COLOR_RESET);
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}
remaining_tokens -= line_inp.size();
input_noecho = true; // do not echo this again
if(stop) {
break;
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive) {
// Check if each of the reverse prompts appears at the end of the output.
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
break;
}
}
if (is_interacting) {
if (params.instruct) {
input_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
printf("\n> ");
}
// currently being interactive
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
std::string buffer;
std::string line;
bool another_line = true;
do {
std::getline(std::cin, line);
if (line.empty() || line.back() != '\\') {
another_line = false;
} else {
line.pop_back(); // Remove the continue character
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);
if (params.use_color) printf(ANSI_COLOR_RESET);
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}
remaining_tokens -= line_inp.size();
input_noecho = true; // do not echo this again
}
is_interacting = false;
}
is_interacting = false;
}
// end of text token

View file

@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--stop") {
params.stop_keyword.push_back(argv[++i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {

View file

@ -32,6 +32,7 @@ struct gpt_params {
std::string prompt = "";
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop
bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided