Clean up llama.com anti/stop/reverse-prompt code

Example use case for JSON completion:

    $ m=opt
    $ make -j16 m=$m o/$m/third_party/ggml/llama.com
    $ o/$m/third_party/ggml/llama.com -m llama.bin -p '{"key": "life", "val": ' -r '}'
    42}

This provides better control. More sophisticated facilities for
controlling text generation will be provided soon enough.
This commit is contained in:
Justine Tunney 2023-05-12 08:20:58 -07:00
parent bbfe4fbd11
commit 80c174d494
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
3 changed files with 100 additions and 104 deletions

View file

@ -268,7 +268,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.mem_test = true;
} else if (arg == "--verbose-prompt") {
params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
} else if (arg == "-r" || arg == "--stop" || arg == "--reverse-prompt") {
if (++i >= argc) {
invalid_param = true;
break;
@ -373,10 +373,9 @@ void gpt_print_usage(FILE *f, int /*argc*/, char ** argv, const gpt_params & par
fprintf(f, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(f, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(f, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(f, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(f, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(f, " specified more than once for multiple prompts).\n");
fprintf(f, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(f, " -r PROMPT, --stop PROMPT, --reverse-prompt PROMPT\n");
fprintf(f, " stop generating text when the specified text is encountered.\n");
fprintf(f, " this option may be repeated.\n");
fprintf(f, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(f, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(f, " -p PROMPT, --prompt PROMPT\n");

View file

@ -58,10 +58,8 @@ $(THIRD_PARTY_GGML_A_OBJS): private \
-mfma
endif
o/rel/third_party/ggml/ggml.o \
o/opt/third_party/ggml/ggml.o: private \
OVERRIDE_CFLAGS += \
-fomit-frame-pointer \
-x-no-pg
ifeq ($(ARCH), x86_64)

View file

@ -47,17 +47,23 @@
#include "third_party/libcxx/string"
#include "third_party/libcxx/vector"
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
asm(".ident\t\"\\n\\n\
llama.cpp (MIT License)\\n\
Copyright (c) 2023 Georgi Gerganov\"");
asm(".include \"libc/disclaimer.inc\"");
// clang-format off
// File-scope state shared between main() and the helper functions in
// this file (remember_init(), has_antiprompt(), and
// finish_initializing_prompt() all read these directly).

// Program options; helpers consult params.antiprompt and params.interactive.
static gpt_params params;
// Model context handle passed to llama_n_ctx() and llama_token_to_str().
static llama_context * ctx;
// Console state passed to console_set_color().
static console_state con_st;
////////////////////////////////////////////////////////////////////////////////
// Set to true by finish_initializing_prompt() when params.interactive is
// enabled; main() tests it to decide whether to keep generating.
static std::atomic<bool> is_interacting;
// Set to true by sigint_handler_batch().
// NOTE(review): the code that polls this flag is outside this view.
static std::atomic<bool> is_terminated;
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
// Signal handler that only records a termination request by raising the
// is_terminated flag; the signal number itself is not inspected.
// NOTE(review): presumably installed for SIGINT in non-interactive mode --
// confirm at the registration site.
static void sigint_handler_batch(int signo) {
    (void) signo;
    is_terminated.store(true);
}
@ -78,6 +84,80 @@ static int CompareTime(struct timespec a, struct timespec b) {
return cmp;
}
////////////////////////////////////////////////////////////////////////////////
// Progress of loading the initial prompt, as tracked in prompt_status.
enum jtlp_status {
// Initial value of prompt_status.
kPromptPending,
// The prompt has just finished loading; main() reacts to this state by
// saving the prompt to disk (when params.prompt_path is set) and then
// calling finish_initializing_prompt().
kPromptCompleted,
// Assigned by finish_initializing_prompt().
kPromptFinished
};
// Fixed-size header made up entirely of byte arrays: a 4-byte magic, a
// 4-byte version, then six 8-byte fields.
// NOTE(review): presumably the header of the prompt file that main() writes
// when params.prompt_path is set, with integers serialized in a fixed byte
// order -- confirm against the code that reads and writes it.
struct jtlp_header {
// presumably holds kJtlpMagic -- TODO confirm
uint8_t magic[4];
// presumably holds kJtlpVersion -- TODO confirm
uint8_t version[4];
uint8_t state_size[8];
// the model_* names suggest fields of the model file's `struct stat`
// (device, inode, mtime seconds/nanoseconds) -- TODO confirm
uint8_t model_dev[8];
uint8_t model_ino[8];
uint8_t model_mtim_sec[8];
uint8_t model_mtim_nsec[8];
uint8_t prompt_size[8];
};
// Magic number: the ASCII bytes "jtlp" packed into a 32-bit word with 'j'
// in the least significant byte.
constexpr uint32_t kJtlpMagic = (static_cast<uint32_t>('j') << 0)  |
                                (static_cast<uint32_t>('t') << 8)  |
                                (static_cast<uint32_t>('l') << 16) |
                                (static_cast<uint32_t>('p') << 24);
// Version number that accompanies the magic.
constexpr uint32_t kJtlpVersion = 0;
// Rolling tail of the text of recently remembered tokens. remember_token()
// appends to it and trims it to at most longest_antiprompt bytes;
// has_antiprompt() searches it.
static std::string last_output;
// Sliding window of the most recent token ids. remember_init() sizes it to
// llama_n_ctx(ctx); remember_token() drops the oldest entry and appends the
// newest, so the size stays constant.
static std::vector<llama_token> last_n_tokens;
// Byte length of the longest entry in params.antiprompt, as computed by
// remember_init().
static std::string::size_type longest_antiprompt;
// Current stage of initial prompt loading.
static enum jtlp_status prompt_status = kPromptPending;
// Prepares the output-tracking state: empties last_output, sizes the token
// window to the context length (new slots are zero), and raises
// longest_antiprompt to the byte length of the longest configured
// antiprompt.
static void remember_init() {
    last_output.clear();
    last_n_tokens.resize(llama_n_ctx(ctx), 0);
    for (const auto & stop : params.antiprompt) {
        if (longest_antiprompt < stop.size()) {
            longest_antiprompt = stop.size();
        }
    }
}
// Records one token: slides the fixed-size token window forward by one and
// appends the token's text to last_output, keeping only the trailing
// longest_antiprompt bytes of that text.
static void remember_token(llama_token tok) {
    // drop the oldest token id, then append the newest
    last_n_tokens.erase(last_n_tokens.begin());
    last_n_tokens.push_back(tok);

    last_output += llama_token_to_str(ctx, tok);

    // nothing to trim while the tail still fits
    const std::string::size_type have = last_output.size();
    if (have <= longest_antiprompt) {
        return;
    }
    last_output.erase(0, have - longest_antiprompt);
}
// Reports whether any configured antiprompt (stop / reverse prompt) occurs
// in last_output, the rolling tail of recently remembered text.
//
// The entries of params.antiprompt are tried in order and the first one
// that is found wins. Empty entries are ignored.
//
// @param out_index       if non-null, receives the offset within
//                        last_output of the last occurrence of the
//                        matching antiprompt
// @param out_antiprompt  if non-null, receives a copy of the matching
//                        antiprompt
// @return true if some non-empty antiprompt was found; when false the
//         out-parameters are left untouched
static bool has_antiprompt(std::string::size_type *out_index = nullptr,
                           std::string *out_antiprompt = nullptr) {
    for (const std::string & antiprompt : params.antiprompt) {
        // rfind() "finds" an empty needle at the end of any string, so an
        // empty antiprompt would report a match on every single call; it
        // can never be a meaningful stop sequence, so skip it
        if (antiprompt.empty()) {
            continue;
        }
        const std::string::size_type index = last_output.rfind(antiprompt);
        if (index != std::string::npos) {
            if (out_index) *out_index = index;
            if (out_antiprompt) *out_antiprompt = antiprompt;
            return true;
        }
    }
    return false;
}
// Marks the initial prompt as finished. In interactive mode it also hands
// control to the user: is_interacting is raised, and if the remembered
// output contains an antiprompt, the text from that match to the end of
// last_output is echoed in the prompt color and last_output is emptied.
// The console is then switched to the user-input color.
static void finish_initializing_prompt() {
    prompt_status = kPromptFinished;
    if (!params.interactive) {
        return;
    }

    is_interacting = true;

    std::string::size_type match_at;
    if (has_antiprompt(&match_at)) {
        console_set_color(con_st, CONSOLE_COLOR_PROMPT);
        const std::string tail = last_output.substr(match_at);
        printf("%s", tail.c_str());
        last_output.clear();
        fflush(stdout);
    }

    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
////////////////////////////////////////////////////////////////////////////////
static int on_missing_feature(const char *name) {
fprintf(stderr, "%s: error: cpuid %s not detected\n", __func__, name);
fprintf(stderr, "%s: amd microprocessors made after 2017 usually work\n", __func__);
@ -86,7 +166,6 @@ static int on_missing_feature(const char *name) {
}
int main(int argc, char ** argv) {
gpt_params params;
ShowCrashReports();
setvbuf(stdin, NULL, _IONBF, 0);
@ -111,7 +190,6 @@ int main(int argc, char ** argv) {
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
static console_state con_st;
con_st.use_color = params.use_color;
con_st.multiline_input = params.multiline_input;
@ -155,7 +233,6 @@ int main(int argc, char ** argv) {
// params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)";
llama_context * ctx;
struct stat model_stat;
// load the model and apply lora adapter, if any
@ -309,10 +386,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.verbose && params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n"
" - Press Ctrl+C to interject at any time.\n"
@ -321,27 +394,7 @@ int main(int argc, char ** argv) {
is_interacting = params.interactive_first;
}
const uint32_t kJtlpMagic = READ32LE("jtlp");
const uint32_t kJtlpVersion = 0;
struct jtlp_header {
uint8_t magic[4];
uint8_t version[4];
uint8_t state_size[8];
uint8_t model_dev[8];
uint8_t model_ino[8];
uint8_t model_mtim_sec[8];
uint8_t model_mtim_nsec[8];
uint8_t prompt_size[8];
};
enum jtlp_status {
kPromptPending,
kPromptCompleted,
kPromptFinished
};
enum jtlp_status prompt_status = kPromptPending;
remember_init();
bool is_antiprompt = false;
bool input_noecho = !params.verbose;
@ -442,27 +495,10 @@ int main(int argc, char ** argv) {
// now setup the business logic
llama_set_rng_seed(ctx, params.seed);
while ((int) embd_inp.size() > n_consumed) {
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed++]);
remember_token(embd_inp[n_consumed++]);
}
n_past = n_consumed;
prompt_status = kPromptFinished;
if (params.interactive) {
is_interacting = true;
for (std::string & antiprompt : params.antiprompt) {
auto toks = ::llama_tokenize(ctx, antiprompt, false);
if (std::equal(last_n_tokens.end() - toks.size(),
last_n_tokens.end(),
toks.begin(),
toks.end())) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
break;
}
}
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
finish_initializing_prompt();
CantReloadPrompt:
if (map != MAP_FAILED) {
munmap(map, file_size);
@ -516,8 +552,8 @@ int main(int argc, char ** argv) {
}
// save prompt to disk atomically as soon as it's finished loading
bool was_completed = prompt_status == kPromptCompleted;
if (was_completed && !params.prompt_path.empty()) {
bool just_finished_initializing_prompt = prompt_status == kPromptCompleted;
if (just_finished_initializing_prompt && !params.prompt_path.empty()) {
int fd = -1;
int close_rc;
uint8_t buf[8];
@ -588,33 +624,11 @@ int main(int argc, char ** argv) {
if (fd != -1) close(fd);
if (!tmppath.empty()) unlink(tmppath.c_str());
}
if (was_completed) {
if (just_finished_initializing_prompt) {
if (!params.verbose && con_st.use_color) {
fprintf(stderr, EPHEMERAL(""));
}
if (params.interactive) {
is_interacting = true;
}
prompt_status = kPromptFinished;
if (params.interactive) {
is_interacting = true;
fflush(stdout);
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(),
last_output.length() - antiprompt.length(),
antiprompt.length()) != std::string::npos) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
break;
}
}
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
finish_initializing_prompt();
}
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
@ -689,8 +703,7 @@ int main(int argc, char ** argv) {
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
remember_token(id);
}
// replace end of text token with newline token when in interactive mode
@ -716,8 +729,7 @@ int main(int argc, char ** argv) {
// some user input remains from prompt or interaction, forward it to processing
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed++]);
remember_token(embd_inp[n_consumed++]);
if ((int) embd.size() >= params.n_batch) {
break;
}
@ -745,24 +757,7 @@ int main(int argc, char ** argv) {
// --prompt 'Question: How old are you?\nAnswer: '
// --reverse-prompt $'\n'
//
if (params.antiprompt.size()) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_antiprompt = true;
break;
}
}
if (is_antiprompt && !params.interactive) {
printf("\n");
break;
}
}
is_antiprompt = has_antiprompt();
// display text
if (!input_noecho) {
@ -771,6 +766,10 @@ int main(int argc, char ** argv) {
}
fflush(stdout);
}
if (is_antiprompt && !params.interactive) {
printf("\n");
break;
}
if (prompt_status == kPromptCompleted) {
continue; // avoid reading line before last token loads
}