mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-03-15 05:16:30 +00:00
Clean up llama.com anti/stop/reverse-prompt code
Example use case for JSON completion: $ m=opt $ make -j16 m=$m o/$m/third_party/ggml/llama.com $ o/$m/third_party/ggml/llama.com -m llama.bin -p '{"key": "life", "val": ' -r '}' 42} This provides better control. More sophisticated facilities for controlling text generation will be provided soon enough.
This commit is contained in:
parent
bbfe4fbd11
commit
80c174d494
3 changed files with 100 additions and 104 deletions
9
third_party/ggml/common.cc
vendored
9
third_party/ggml/common.cc
vendored
|
@ -268,7 +268,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.mem_test = true;
|
||||
} else if (arg == "--verbose-prompt") {
|
||||
params.verbose_prompt = true;
|
||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||
} else if (arg == "-r" || arg == "--stop" || arg == "--reverse-prompt") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
|
@ -373,10 +373,9 @@ void gpt_print_usage(FILE *f, int /*argc*/, char ** argv, const gpt_params & par
|
|||
fprintf(f, " --interactive-first run in interactive mode and wait for input right away\n");
|
||||
fprintf(f, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
|
||||
fprintf(f, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
|
||||
fprintf(f, " -r PROMPT, --reverse-prompt PROMPT\n");
|
||||
fprintf(f, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
|
||||
fprintf(f, " specified more than once for multiple prompts).\n");
|
||||
fprintf(f, " --color colorise output to distinguish prompt and user input from generations\n");
|
||||
fprintf(f, " -r PROMPT, --stop PROMPT, --reverse-prompt PROMPT\n");
|
||||
fprintf(f, " stop generating text when the specified text is encountered.\n");
|
||||
fprintf(f, " this option may be repeated.\n");
|
||||
fprintf(f, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
|
||||
fprintf(f, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
fprintf(f, " -p PROMPT, --prompt PROMPT\n");
|
||||
|
|
2
third_party/ggml/ggml.mk
vendored
2
third_party/ggml/ggml.mk
vendored
|
@ -58,10 +58,8 @@ $(THIRD_PARTY_GGML_A_OBJS): private \
|
|||
-mfma
|
||||
endif
|
||||
|
||||
o/rel/third_party/ggml/ggml.o \
|
||||
o/opt/third_party/ggml/ggml.o: private \
|
||||
OVERRIDE_CFLAGS += \
|
||||
-fomit-frame-pointer \
|
||||
-x-no-pg
|
||||
|
||||
ifeq ($(ARCH), x86_64)
|
||||
|
|
193
third_party/ggml/main.cc
vendored
193
third_party/ggml/main.cc
vendored
|
@ -47,17 +47,23 @@
|
|||
#include "third_party/libcxx/string"
|
||||
#include "third_party/libcxx/vector"
|
||||
|
||||
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
|
||||
|
||||
asm(".ident\t\"\\n\\n\
|
||||
llama.cpp (MIT License)\\n\
|
||||
Copyright (c) 2023 Georgi Gerganov\"");
|
||||
asm(".include \"libc/disclaimer.inc\"");
|
||||
// clang-format off
|
||||
|
||||
static gpt_params params;
|
||||
static llama_context * ctx;
|
||||
static console_state con_st;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static std::atomic<bool> is_interacting;
|
||||
static std::atomic<bool> is_terminated;
|
||||
|
||||
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
|
||||
|
||||
static void sigint_handler_batch(int signo) {
|
||||
is_terminated = true;
|
||||
}
|
||||
|
@ -78,6 +84,80 @@ static int CompareTime(struct timespec a, struct timespec b) {
|
|||
return cmp;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum jtlp_status {
|
||||
kPromptPending,
|
||||
kPromptCompleted,
|
||||
kPromptFinished
|
||||
};
|
||||
|
||||
struct jtlp_header {
|
||||
uint8_t magic[4];
|
||||
uint8_t version[4];
|
||||
uint8_t state_size[8];
|
||||
uint8_t model_dev[8];
|
||||
uint8_t model_ino[8];
|
||||
uint8_t model_mtim_sec[8];
|
||||
uint8_t model_mtim_nsec[8];
|
||||
uint8_t prompt_size[8];
|
||||
};
|
||||
|
||||
constexpr uint32_t kJtlpMagic = 'j' | 't' << 8 | 'l' << 16 | 'p' << 24;
|
||||
constexpr uint32_t kJtlpVersion = 0;
|
||||
|
||||
static std::string last_output;
|
||||
static std::vector<llama_token> last_n_tokens;
|
||||
static std::string::size_type longest_antiprompt;
|
||||
static enum jtlp_status prompt_status = kPromptPending;
|
||||
|
||||
static void remember_init() {
|
||||
last_output.clear();
|
||||
last_n_tokens.resize(llama_n_ctx(ctx), 0);
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
|
||||
}
|
||||
}
|
||||
|
||||
static void remember_token(llama_token tok) {
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(tok);
|
||||
last_output.append(llama_token_to_str(ctx, tok));
|
||||
if (last_output.size() > longest_antiprompt) {
|
||||
last_output.erase(0, last_output.size() - longest_antiprompt);
|
||||
}
|
||||
}
|
||||
|
||||
static bool has_antiprompt(std::string::size_type *out_index = nullptr,
|
||||
std::string *out_antiprompt = nullptr) {
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
std::string::size_type index = last_output.rfind(antiprompt);
|
||||
if (index != std::string::npos) {
|
||||
if (out_index) *out_index = index;
|
||||
if (out_antiprompt) *out_antiprompt = antiprompt;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void finish_initializing_prompt() {
|
||||
prompt_status = kPromptFinished;
|
||||
if (params.interactive) {
|
||||
std::string::size_type pos;
|
||||
is_interacting = true;
|
||||
if (has_antiprompt(&pos)) {
|
||||
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
printf("%s", last_output.substr(pos).c_str());
|
||||
last_output.clear();
|
||||
fflush(stdout);
|
||||
}
|
||||
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static int on_missing_feature(const char *name) {
|
||||
fprintf(stderr, "%s: error: cpuid %s not detected\n", __func__, name);
|
||||
fprintf(stderr, "%s: amd microprocessors made after 2017 usually work\n", __func__);
|
||||
|
@ -86,7 +166,6 @@ static int on_missing_feature(const char *name) {
|
|||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
ShowCrashReports();
|
||||
setvbuf(stdin, NULL, _IONBF, 0);
|
||||
|
@ -111,7 +190,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// save choice to use color for later
|
||||
// (note for later: this is a slightly awkward choice)
|
||||
static console_state con_st;
|
||||
con_st.use_color = params.use_color;
|
||||
|
||||
con_st.multiline_input = params.multiline_input;
|
||||
|
@ -155,7 +233,6 @@ int main(int argc, char ** argv) {
|
|||
// params.prompt = R"(// this function checks if the number n is prime
|
||||
//bool is_prime(int n) {)";
|
||||
|
||||
llama_context * ctx;
|
||||
struct stat model_stat;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
|
@ -309,10 +386,6 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> last_n_tokens(n_ctx);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
|
||||
if (params.verbose && params.interactive) {
|
||||
fprintf(stderr, "== Running in interactive mode. ==\n"
|
||||
" - Press Ctrl+C to interject at any time.\n"
|
||||
|
@ -321,27 +394,7 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = params.interactive_first;
|
||||
}
|
||||
|
||||
const uint32_t kJtlpMagic = READ32LE("jtlp");
|
||||
const uint32_t kJtlpVersion = 0;
|
||||
|
||||
struct jtlp_header {
|
||||
uint8_t magic[4];
|
||||
uint8_t version[4];
|
||||
uint8_t state_size[8];
|
||||
uint8_t model_dev[8];
|
||||
uint8_t model_ino[8];
|
||||
uint8_t model_mtim_sec[8];
|
||||
uint8_t model_mtim_nsec[8];
|
||||
uint8_t prompt_size[8];
|
||||
};
|
||||
|
||||
enum jtlp_status {
|
||||
kPromptPending,
|
||||
kPromptCompleted,
|
||||
kPromptFinished
|
||||
};
|
||||
|
||||
enum jtlp_status prompt_status = kPromptPending;
|
||||
remember_init();
|
||||
|
||||
bool is_antiprompt = false;
|
||||
bool input_noecho = !params.verbose;
|
||||
|
@ -442,27 +495,10 @@ int main(int argc, char ** argv) {
|
|||
// now setup the business logic
|
||||
llama_set_rng_seed(ctx, params.seed);
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(embd_inp[n_consumed++]);
|
||||
remember_token(embd_inp[n_consumed++]);
|
||||
}
|
||||
n_past = n_consumed;
|
||||
prompt_status = kPromptFinished;
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
auto toks = ::llama_tokenize(ctx, antiprompt, false);
|
||||
if (std::equal(last_n_tokens.end() - toks.size(),
|
||||
last_n_tokens.end(),
|
||||
toks.begin(),
|
||||
toks.end())) {
|
||||
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
printf("%s", antiprompt.c_str());
|
||||
fflush(stdout);
|
||||
break;
|
||||
}
|
||||
}
|
||||
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
||||
}
|
||||
finish_initializing_prompt();
|
||||
CantReloadPrompt:
|
||||
if (map != MAP_FAILED) {
|
||||
munmap(map, file_size);
|
||||
|
@ -516,8 +552,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// save prompt to disk atomically as soon as it's finished loading
|
||||
bool was_completed = prompt_status == kPromptCompleted;
|
||||
if (was_completed && !params.prompt_path.empty()) {
|
||||
bool just_finished_initializing_prompt = prompt_status == kPromptCompleted;
|
||||
if (just_finished_initializing_prompt && !params.prompt_path.empty()) {
|
||||
int fd = -1;
|
||||
int close_rc;
|
||||
uint8_t buf[8];
|
||||
|
@ -588,33 +624,11 @@ int main(int argc, char ** argv) {
|
|||
if (fd != -1) close(fd);
|
||||
if (!tmppath.empty()) unlink(tmppath.c_str());
|
||||
}
|
||||
if (was_completed) {
|
||||
if (just_finished_initializing_prompt) {
|
||||
if (!params.verbose && con_st.use_color) {
|
||||
fprintf(stderr, EPHEMERAL(""));
|
||||
}
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
prompt_status = kPromptFinished;
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
fflush(stdout);
|
||||
std::string last_output;
|
||||
for (auto id : last_n_tokens) {
|
||||
last_output += llama_token_to_str(ctx, id);
|
||||
}
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
if (last_output.find(antiprompt.c_str(),
|
||||
last_output.length() - antiprompt.length(),
|
||||
antiprompt.length()) != std::string::npos) {
|
||||
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
printf("%s", antiprompt.c_str());
|
||||
fflush(stdout);
|
||||
break;
|
||||
}
|
||||
}
|
||||
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
||||
}
|
||||
finish_initializing_prompt();
|
||||
}
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
|
@ -689,8 +703,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(id);
|
||||
remember_token(id);
|
||||
}
|
||||
|
||||
// replace end of text token with newline token when in interactive mode
|
||||
|
@ -716,8 +729,7 @@ int main(int argc, char ** argv) {
|
|||
// some user input remains from prompt or interaction, forward it to processing
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
embd.push_back(embd_inp[n_consumed]);
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(embd_inp[n_consumed++]);
|
||||
remember_token(embd_inp[n_consumed++]);
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
break;
|
||||
}
|
||||
|
@ -745,24 +757,7 @@ int main(int argc, char ** argv) {
|
|||
// --prompt 'Question: How old are you?\nAnswer: '
|
||||
// --reverse-prompt $'\n'
|
||||
//
|
||||
if (params.antiprompt.size()) {
|
||||
std::string last_output;
|
||||
for (auto id : last_n_tokens) {
|
||||
last_output += llama_token_to_str(ctx, id);
|
||||
}
|
||||
is_antiprompt = false;
|
||||
// Check if each of the reverse prompts appears at the end of the output.
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||
is_antiprompt = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_antiprompt && !params.interactive) {
|
||||
printf("\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
is_antiprompt = has_antiprompt();
|
||||
|
||||
// display text
|
||||
if (!input_noecho) {
|
||||
|
@ -771,6 +766,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
if (is_antiprompt && !params.interactive) {
|
||||
printf("\n");
|
||||
break;
|
||||
}
|
||||
if (prompt_status == kPromptCompleted) {
|
||||
continue; // avoid reading line before last token loads
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue