From 4a8a81eb9fdeaf59a6b6c417a366b6e4efbf5d27 Mon Sep 17 00:00:00 2001
From: Justine Tunney
Date: Sat, 13 May 2023 00:09:38 -0700
Subject: [PATCH] Fix llama.com interactive mode regressions

---
 libc/intrin/describeiovec.c |  8 +---
 third_party/ggml/common.cc  |  6 ++-
 third_party/ggml/main.cc    | 75 +++++++++++++++++++++++++++++--------
 3 files changed, 67 insertions(+), 22 deletions(-)

diff --git a/libc/intrin/describeiovec.c b/libc/intrin/describeiovec.c
index bd8ae3b03..364736a7e 100644
--- a/libc/intrin/describeiovec.c
+++ b/libc/intrin/describeiovec.c
@@ -24,16 +24,12 @@
 #include "libc/limits.h"
 #include "libc/macros.internal.h"
 
-#ifdef DescribeIovec
-#undef DescribeIovec
-#endif
-
 #define N 300
 
 #define append(...) o += ksnprintf(buf + o, N - o, __VA_ARGS__)
 
-const char *DescribeIovec(char buf[N], ssize_t rc, const struct iovec *iov,
-                          int iovlen) {
+const char *(DescribeIovec)(char buf[N], ssize_t rc, const struct iovec *iov,
+                            int iovlen) {
   const char *d;
   int i, j, o = 0;
 
diff --git a/third_party/ggml/common.cc b/third_party/ggml/common.cc
index e75dfd9f4..b4270e60c 100644
--- a/third_party/ggml/common.cc
+++ b/third_party/ggml/common.cc
@@ -338,7 +338,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         fprintf(stderr, "%s: No prompt specified\n", __func__);
         fprintf(stderr, "%s: Loading CompanionAI\n", __func__);
     }
-    append_file_to_prompt("/zip/companionai.txt", params);
+    if (fileexists("third_party/ggml/companionai.txt")) {
+        append_file_to_prompt("third_party/ggml/companionai.txt", params);
+    } else {
+        append_file_to_prompt("/zip/companionai.txt", params);
+    }
     const char *user;
     user = getenv("USER");
     if (!user || !*user) {
diff --git a/third_party/ggml/main.cc b/third_party/ggml/main.cc
index cda679bb4..cb8f1b845 100644
--- a/third_party/ggml/main.cc
+++ b/third_party/ggml/main.cc
@@ -30,11 +30,12 @@
 #include "libc/calls/calls.h"
 #include "libc/calls/struct/sigaction.h"
 #include "libc/calls/struct/stat.h"
+#include "libc/fmt/fmt.h"
 #include "libc/intrin/bits.h"
-#include "libc/intrin/kprintf.h"
 #include "libc/log/log.h"
 #include "libc/macros.internal.h"
 #include "libc/nexgen32e/x86feature.h"
+#include "libc/runtime/runtime.h"
 #include "libc/stdio/stdio.h"
 #include "libc/sysv/consts/map.h"
 #include "libc/sysv/consts/msync.h"
@@ -61,6 +62,12 @@ static gpt_params params;
 static llama_context * ctx;
 static console_state con_st;
 
+static int n_past;
+static int n_remain;
+static int n_consumed;
+static bool input_noecho;
+static bool is_antiprompt;
+
 ////////////////////////////////////////////////////////////////////////////////
 
 static std::atomic<bool> is_stalled;
@@ -93,6 +100,24 @@ static int CompareTime(struct timespec a, struct timespec b) {
   return cmp;
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// ux explanatory logging for llama.com developers
+#if 1
+#define DEVLOG(...) (void)0
+#else
+#define DEVLOG(...) \
+  if (g_devlog) fprintf(g_devlog, __VA_ARGS__)
+static FILE *g_devlog;
+__attribute__((__constructor__)) static void init(void) {
+  char path[PATH_MAX];
+  static char linebuf[4096];
+  snprintf(path, sizeof(path), "/tmp/llama-%s.log", getenv("USER"));
+  if ((g_devlog = fopen(path, "wa"))) {
+    setvbuf(g_devlog, linebuf, _IOLBF, sizeof(linebuf));
+  }
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////
 
 enum jtlp_status {
@@ -129,13 +154,19 @@ static void remember_init() {
   longest_antiprompt += llama_longest_token(ctx) * 2;
 }
 
-static void remember_token(llama_token tok) {
+static void remember_token(llama_token tok,
+                           bool is_user_input = false) {
   last_n_tokens.erase(last_n_tokens.begin());
   last_n_tokens.push_back(tok);
-  last_output.append(llama_token_to_str(ctx, tok));
-  if (last_output.size() > longest_antiprompt) {
-    last_output.erase(0, last_output.size() - longest_antiprompt);
+  if (!is_user_input) {
+    last_output.append(llama_token_to_str(ctx, tok));
+    if (last_output.size() > longest_antiprompt) {
+      last_output.erase(0, last_output.size() - longest_antiprompt);
+    }
   }
+  DEVLOG("remember_token(%`'s, %d) -> %`'s\n",
+         llama_token_to_str(ctx, tok), is_user_input,
+         last_output.c_str());
 }
 
 static bool has_antiprompt(std::string::size_type *out_index = nullptr,
@@ -145,6 +176,8 @@
     if (index != std::string::npos) {
      if (out_index) *out_index = index;
      if (out_antiprompt) *out_antiprompt = antiprompt;
+      DEVLOG("found antiprompt %`'s at index %d of %`'s\n",
+             antiprompt.c_str(), (int)index, last_output.c_str());
      return true;
    }
  }
@@ -406,12 +439,12 @@ int main(int argc, char ** argv) {
 
     remember_init();
 
-    bool is_antiprompt = false;
-    bool input_noecho = params.verbose <= 0;
+    is_antiprompt = false;
+    input_noecho = params.verbose <= 0;
 
-    int n_past = 0;
-    int n_remain = params.n_predict;
-    int n_consumed = 0;
+    n_past = 0;
+    n_remain = params.n_predict;
+    n_consumed = 0;
 
     // instantly reload prompt if it's cached
     int fd = open(params.prompt_path.c_str(), O_RDONLY);
@@ -532,6 +565,7 @@ int main(int argc, char ** argv) {
 
         // perform evaluation
         if (embd.size() > 0) {
+            DEVLOG("perform evaluation embd.size()=%d\n", (int)embd.size());
             if (n_past + (int) embd.size() > n_ctx) {
                 n_past = n_keep;
                 embd.insert(embd.begin(),
@@ -544,6 +578,7 @@ int main(int argc, char ** argv) {
                 n_eval = params.n_batch;
             }
             is_stalled = n_eval > 1;
+            DEVLOG("llama_eval(n_eval=%d, n_past=%d)\n", n_eval, n_past);
             if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
@@ -645,6 +680,8 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
+            DEVLOG("out of user input, sample next token w/ embd_inp.size()=%d n_consumed=%d\n",
+                   (int)embd_inp.size(), n_consumed);
             const float temp = params.temp;
             const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
             const float top_p = params.top_p;
@@ -738,10 +775,12 @@ int main(int argc, char ** argv) {
             --n_remain;
         } else {
+            DEVLOG("some user input remains from prompt or interaction w/ embd_inp.size()=%d n_consumed=%d\n",
+                   (int)embd_inp.size(), n_consumed);
             // some user input remains from prompt or interaction, forward it to processing
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                remember_token(embd_inp[n_consumed++]);
+                remember_token(embd_inp[n_consumed++], true);
                 if ((int) embd.size() >= params.n_batch) {
                     break;
                 }
             }
@@ -791,13 +830,19 @@ int main(int argc, char ** argv) {
                 fflush(stdout);
             }
         }
-        if (is_antiprompt && !params.interactive) {
-            if (!got_newline) {
-                printf("\n");
+        if (is_antiprompt) {
+            if (!params.interactive) {
+                if (!got_newline) {
+                    printf("\n");
+                }
+                break;
             }
-            break;
+            // scrub antiprompt so it must be typed again to be detected
+            last_output.erase(0, ap_index + ap_text.size());
+            DEVLOG("scrubbed antiprompt -> %`'s\n", last_output.c_str());
         }
 
         if (prompt_status == kPromptCompleted) {
+            DEVLOG("avoid reading line before last token loads\n");
             continue; // avoid reading line before last token loads
         }