Clean up llama.com anti/stop/reverse-prompt code

Example use case for JSON completion:

    $ m=opt
    $ make -j16 m=$m o/$m/third_party/ggml/llama.com
    $ o/$m/third_party/ggml/llama.com -m llama.bin -p '{"key": "life", "val": ' -r '}'
    42}

This provides better control. More sophisticated facilities for
controlling text generation will be provided soon enough.
This commit is contained in:
Justine Tunney 2023-05-12 08:20:58 -07:00
parent bbfe4fbd11
commit 80c174d494
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
3 changed files with 100 additions and 104 deletions

View file

@ -268,7 +268,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.mem_test = true; params.mem_test = true;
} else if (arg == "--verbose-prompt") { } else if (arg == "--verbose-prompt") {
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--stop" || arg == "--reverse-prompt") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -373,10 +373,9 @@ void gpt_print_usage(FILE *f, int /*argc*/, char ** argv, const gpt_params & par
fprintf(f, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(f, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(f, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(f, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(f, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(f, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(f, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(f, " -r PROMPT, --stop PROMPT, --reverse-prompt PROMPT\n");
fprintf(f, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(f, " stop generating text when the specified text is encountered.\n");
fprintf(f, " specified more than once for multiple prompts).\n"); fprintf(f, " this option may be repeated.\n");
fprintf(f, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(f, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(f, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(f, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(f, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(f, " -p PROMPT, --prompt PROMPT\n"); fprintf(f, " -p PROMPT, --prompt PROMPT\n");

View file

@ -58,10 +58,8 @@ $(THIRD_PARTY_GGML_A_OBJS): private \
-mfma -mfma
endif endif
o/rel/third_party/ggml/ggml.o \
o/opt/third_party/ggml/ggml.o: private \ o/opt/third_party/ggml/ggml.o: private \
OVERRIDE_CFLAGS += \ OVERRIDE_CFLAGS += \
-fomit-frame-pointer \
-x-no-pg -x-no-pg
ifeq ($(ARCH), x86_64) ifeq ($(ARCH), x86_64)

View file

@ -47,17 +47,23 @@
#include "third_party/libcxx/string" #include "third_party/libcxx/string"
#include "third_party/libcxx/vector" #include "third_party/libcxx/vector"
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
asm(".ident\t\"\\n\\n\ asm(".ident\t\"\\n\\n\
llama.cpp (MIT License)\\n\ llama.cpp (MIT License)\\n\
Copyright (c) 2023 Georgi Gerganov\""); Copyright (c) 2023 Georgi Gerganov\"");
asm(".include \"libc/disclaimer.inc\""); asm(".include \"libc/disclaimer.inc\"");
// clang-format off // clang-format off
static gpt_params params;
static llama_context * ctx;
static console_state con_st;
////////////////////////////////////////////////////////////////////////////////
static std::atomic<bool> is_interacting; static std::atomic<bool> is_interacting;
static std::atomic<bool> is_terminated; static std::atomic<bool> is_terminated;
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
static void sigint_handler_batch(int signo) { static void sigint_handler_batch(int signo) {
is_terminated = true; is_terminated = true;
} }
@ -78,6 +84,80 @@ static int CompareTime(struct timespec a, struct timespec b) {
return cmp; return cmp;
} }
////////////////////////////////////////////////////////////////////////////////
enum jtlp_status {
kPromptPending,
kPromptCompleted,
kPromptFinished
};
struct jtlp_header {
uint8_t magic[4];
uint8_t version[4];
uint8_t state_size[8];
uint8_t model_dev[8];
uint8_t model_ino[8];
uint8_t model_mtim_sec[8];
uint8_t model_mtim_nsec[8];
uint8_t prompt_size[8];
};
constexpr uint32_t kJtlpMagic = 'j' | 't' << 8 | 'l' << 16 | 'p' << 24;
constexpr uint32_t kJtlpVersion = 0;
static std::string last_output;
static std::vector<llama_token> last_n_tokens;
static std::string::size_type longest_antiprompt;
static enum jtlp_status prompt_status = kPromptPending;
static void remember_init() {
last_output.clear();
last_n_tokens.resize(llama_n_ctx(ctx), 0);
for (std::string & antiprompt : params.antiprompt) {
longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
}
}
static void remember_token(llama_token tok) {
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(tok);
last_output.append(llama_token_to_str(ctx, tok));
if (last_output.size() > longest_antiprompt) {
last_output.erase(0, last_output.size() - longest_antiprompt);
}
}
static bool has_antiprompt(std::string::size_type *out_index = nullptr,
std::string *out_antiprompt = nullptr) {
for (std::string & antiprompt : params.antiprompt) {
std::string::size_type index = last_output.rfind(antiprompt);
if (index != std::string::npos) {
if (out_index) *out_index = index;
if (out_antiprompt) *out_antiprompt = antiprompt;
return true;
}
}
return false;
}
static void finish_initializing_prompt() {
prompt_status = kPromptFinished;
if (params.interactive) {
std::string::size_type pos;
is_interacting = true;
if (has_antiprompt(&pos)) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", last_output.substr(pos).c_str());
last_output.clear();
fflush(stdout);
}
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
}
////////////////////////////////////////////////////////////////////////////////
static int on_missing_feature(const char *name) { static int on_missing_feature(const char *name) {
fprintf(stderr, "%s: error: cpuid %s not detected\n", __func__, name); fprintf(stderr, "%s: error: cpuid %s not detected\n", __func__, name);
fprintf(stderr, "%s: amd microprocessors made after 2017 usually work\n", __func__); fprintf(stderr, "%s: amd microprocessors made after 2017 usually work\n", __func__);
@ -86,7 +166,6 @@ static int on_missing_feature(const char *name) {
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params;
ShowCrashReports(); ShowCrashReports();
setvbuf(stdin, NULL, _IONBF, 0); setvbuf(stdin, NULL, _IONBF, 0);
@ -111,7 +190,6 @@ int main(int argc, char ** argv) {
// save choice to use color for later // save choice to use color for later
// (note for later: this is a slightly awkward choice) // (note for later: this is a slightly awkward choice)
static console_state con_st;
con_st.use_color = params.use_color; con_st.use_color = params.use_color;
con_st.multiline_input = params.multiline_input; con_st.multiline_input = params.multiline_input;
@ -155,7 +233,6 @@ int main(int argc, char ** argv) {
// params.prompt = R"(// this function checks if the number n is prime // params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)"; //bool is_prime(int n) {)";
llama_context * ctx;
struct stat model_stat; struct stat model_stat;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
@ -309,10 +386,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n"); fprintf(stderr, "\n\n");
} }
// TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.verbose && params.interactive) { if (params.verbose && params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n" fprintf(stderr, "== Running in interactive mode. ==\n"
" - Press Ctrl+C to interject at any time.\n" " - Press Ctrl+C to interject at any time.\n"
@ -321,27 +394,7 @@ int main(int argc, char ** argv) {
is_interacting = params.interactive_first; is_interacting = params.interactive_first;
} }
const uint32_t kJtlpMagic = READ32LE("jtlp"); remember_init();
const uint32_t kJtlpVersion = 0;
struct jtlp_header {
uint8_t magic[4];
uint8_t version[4];
uint8_t state_size[8];
uint8_t model_dev[8];
uint8_t model_ino[8];
uint8_t model_mtim_sec[8];
uint8_t model_mtim_nsec[8];
uint8_t prompt_size[8];
};
enum jtlp_status {
kPromptPending,
kPromptCompleted,
kPromptFinished
};
enum jtlp_status prompt_status = kPromptPending;
bool is_antiprompt = false; bool is_antiprompt = false;
bool input_noecho = !params.verbose; bool input_noecho = !params.verbose;
@ -442,27 +495,10 @@ int main(int argc, char ** argv) {
// now setup the business logic // now setup the business logic
llama_set_rng_seed(ctx, params.seed); llama_set_rng_seed(ctx, params.seed);
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
last_n_tokens.erase(last_n_tokens.begin()); remember_token(embd_inp[n_consumed++]);
last_n_tokens.push_back(embd_inp[n_consumed++]);
} }
n_past = n_consumed; n_past = n_consumed;
prompt_status = kPromptFinished; finish_initializing_prompt();
if (params.interactive) {
is_interacting = true;
for (std::string & antiprompt : params.antiprompt) {
auto toks = ::llama_tokenize(ctx, antiprompt, false);
if (std::equal(last_n_tokens.end() - toks.size(),
last_n_tokens.end(),
toks.begin(),
toks.end())) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
break;
}
}
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
CantReloadPrompt: CantReloadPrompt:
if (map != MAP_FAILED) { if (map != MAP_FAILED) {
munmap(map, file_size); munmap(map, file_size);
@ -516,8 +552,8 @@ int main(int argc, char ** argv) {
} }
// save prompt to disk atomically as soon as it's finished loading // save prompt to disk atomically as soon as it's finished loading
bool was_completed = prompt_status == kPromptCompleted; bool just_finished_initializing_prompt = prompt_status == kPromptCompleted;
if (was_completed && !params.prompt_path.empty()) { if (just_finished_initializing_prompt && !params.prompt_path.empty()) {
int fd = -1; int fd = -1;
int close_rc; int close_rc;
uint8_t buf[8]; uint8_t buf[8];
@ -588,33 +624,11 @@ int main(int argc, char ** argv) {
if (fd != -1) close(fd); if (fd != -1) close(fd);
if (!tmppath.empty()) unlink(tmppath.c_str()); if (!tmppath.empty()) unlink(tmppath.c_str());
} }
if (was_completed) { if (just_finished_initializing_prompt) {
if (!params.verbose && con_st.use_color) { if (!params.verbose && con_st.use_color) {
fprintf(stderr, EPHEMERAL("")); fprintf(stderr, EPHEMERAL(""));
} }
if (params.interactive) { finish_initializing_prompt();
is_interacting = true;
}
prompt_status = kPromptFinished;
if (params.interactive) {
is_interacting = true;
fflush(stdout);
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(),
last_output.length() - antiprompt.length(),
antiprompt.length()) != std::string::npos) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
break;
}
}
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
} }
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
@ -689,8 +703,7 @@ int main(int argc, char ** argv) {
} }
} }
last_n_tokens.erase(last_n_tokens.begin()); remember_token(id);
last_n_tokens.push_back(id);
} }
// replace end of text token with newline token when in interactive mode // replace end of text token with newline token when in interactive mode
@ -716,8 +729,7 @@ int main(int argc, char ** argv) {
// some user input remains from prompt or interaction, forward it to processing // some user input remains from prompt or interaction, forward it to processing
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin()); remember_token(embd_inp[n_consumed++]);
last_n_tokens.push_back(embd_inp[n_consumed++]);
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
break; break;
} }
@ -745,24 +757,7 @@ int main(int argc, char ** argv) {
// --prompt 'Question: How old are you?\nAnswer: ' // --prompt 'Question: How old are you?\nAnswer: '
// --reverse-prompt $'\n' // --reverse-prompt $'\n'
// //
if (params.antiprompt.size()) { is_antiprompt = has_antiprompt();
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_antiprompt = true;
break;
}
}
if (is_antiprompt && !params.interactive) {
printf("\n");
break;
}
}
// display text // display text
if (!input_noecho) { if (!input_noecho) {
@ -771,6 +766,10 @@ int main(int argc, char ** argv) {
} }
fflush(stdout); fflush(stdout);
} }
if (is_antiprompt && !params.interactive) {
printf("\n");
break;
}
if (prompt_status == kPromptCompleted) { if (prompt_status == kPromptCompleted) {
continue; // avoid reading line before last token loads continue; // avoid reading line before last token loads
} }