Check for reverse prompt by characters instead of tokens (#292)

This commit is contained in:
Johnman 2023-03-20 16:06:58 +01:00
parent 074bea2eb1
commit b646ffa1b1

View file

@ -12,6 +12,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <sstream>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h> #include <signal.h>
@ -877,16 +878,9 @@ int main(int argc, char ** argv) {
params.interactive = true; params.interactive = true;
params.antiprompt.push_back("### Instruction:\n\n"); params.antiprompt.push_back("### Instruction:\n\n");
} }
// tokenize the reverse prompt
std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
for (auto antiprompt : params.antiprompt) {
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
}
// enable interactive mode if reverse prompt is specified // enable interactive mode if reverse prompt is specified
if (antipromptv_inp.size() != 0) { if (params.antiprompt.size() != 0) {
params.interactive = true; params.interactive = true;
} }
@ -910,15 +904,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: interactive mode on.\n", __func__); fprintf(stderr, "%s: interactive mode on.\n", __func__);
if(antipromptv_inp.size()) { if(params.antiprompt.size()) {
for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) { for (auto antiprompt : params.antiprompt) {
auto antiprompt_inp = antipromptv_inp.at(apindex); fprintf(stderr, "Antiprompt: %s\n", antiprompt);
fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
}
fprintf(stderr, "\n");
} }
} }
} }
@ -1035,12 +1023,23 @@ int main(int argc, char ** argv) {
// check if we should prompt the user for more // check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) { if (params.interactive && embd_inp.size() <= input_consumed) {
// check for reverse prompt // check for reverse prompt
for (auto antiprompt_inp : antipromptv_inp) {
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { std::stringstream last_output_ss;
// reverse prompt found for (auto id : last_n_tokens) {
last_output_ss << vocab.id_to_token[id];
}
std::string last_output = last_output_ss.str();
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true; is_interacting = true;
break; break;
} }
/*if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
// reverse prompt found
is_interacting = true;
break;
}*/
} }
if (is_interacting) { if (is_interacting) {
if (params.instruct) { if (params.instruct) {