Fix llama.com interactive mode regressions

This commit is contained in:
Justine Tunney 2023-05-13 00:09:38 -07:00
parent fd34ef732d
commit 4a8a81eb9f
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
3 changed files with 67 additions and 22 deletions

View file

@@ -24,16 +24,12 @@
#include "libc/limits.h" #include "libc/limits.h"
#include "libc/macros.internal.h" #include "libc/macros.internal.h"
#ifdef DescribeIovec
#undef DescribeIovec
#endif
#define N 300 #define N 300
#define append(...) o += ksnprintf(buf + o, N - o, __VA_ARGS__) #define append(...) o += ksnprintf(buf + o, N - o, __VA_ARGS__)
const char *DescribeIovec(char buf[N], ssize_t rc, const struct iovec *iov, const char *(DescribeIovec)(char buf[N], ssize_t rc, const struct iovec *iov,
int iovlen) { int iovlen) {
const char *d; const char *d;
int i, j, o = 0; int i, j, o = 0;

View file

@@ -338,7 +338,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "%s: No prompt specified\n", __func__); fprintf(stderr, "%s: No prompt specified\n", __func__);
fprintf(stderr, "%s: Loading CompanionAI\n", __func__); fprintf(stderr, "%s: Loading CompanionAI\n", __func__);
} }
append_file_to_prompt("/zip/companionai.txt", params); if (fileexists("third_party/ggml/companionai.txt")) {
append_file_to_prompt("third_party/ggml/companionai.txt", params);
} else {
append_file_to_prompt("/zip/companionai.txt", params);
}
const char *user; const char *user;
user = getenv("USER"); user = getenv("USER");
if (!user || !*user) { if (!user || !*user) {

View file

@@ -30,11 +30,12 @@
#include "libc/calls/calls.h" #include "libc/calls/calls.h"
#include "libc/calls/struct/sigaction.h" #include "libc/calls/struct/sigaction.h"
#include "libc/calls/struct/stat.h" #include "libc/calls/struct/stat.h"
#include "libc/fmt/fmt.h"
#include "libc/intrin/bits.h" #include "libc/intrin/bits.h"
#include "libc/intrin/kprintf.h"
#include "libc/log/log.h" #include "libc/log/log.h"
#include "libc/macros.internal.h" #include "libc/macros.internal.h"
#include "libc/nexgen32e/x86feature.h" #include "libc/nexgen32e/x86feature.h"
#include "libc/runtime/runtime.h"
#include "libc/stdio/stdio.h" #include "libc/stdio/stdio.h"
#include "libc/sysv/consts/map.h" #include "libc/sysv/consts/map.h"
#include "libc/sysv/consts/msync.h" #include "libc/sysv/consts/msync.h"
@@ -61,6 +62,12 @@ static gpt_params params;
static llama_context * ctx; static llama_context * ctx;
static console_state con_st; static console_state con_st;
static int n_past;
static int n_remain;
static int n_consumed;
static bool input_noecho;
static bool is_antiprompt;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
static std::atomic<bool> is_stalled; static std::atomic<bool> is_stalled;
@@ -93,6 +100,24 @@ static int CompareTime(struct timespec a, struct timespec b) {
return cmp; return cmp;
} }
////////////////////////////////////////////////////////////////////////////////
// ux explanatory logging for llama.com developers
#if 1
#define DEVLOG(...) (void)0
#else
#define DEVLOG(...) if (g_devlog) fprintf(g_devlog, __VA_ARGS__)
static FILE *g_devlog;
__attribute__((__constructor__)) static void init(void) {
char path[PATH_MAX];
static char linebuf[4096];
snprintf(path, sizeof(path), "/tmp/llama-%s.log", getenv("USER"));
if ((g_devlog = fopen(path, "wa"))) {
setvbuf(g_devlog, linebuf, _IOLBF, sizeof(linebuf));
}
}
#endif
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
enum jtlp_status { enum jtlp_status {
@@ -129,13 +154,19 @@ static void remember_init() {
longest_antiprompt += llama_longest_token(ctx) * 2; longest_antiprompt += llama_longest_token(ctx) * 2;
} }
static void remember_token(llama_token tok) { static void remember_token(llama_token tok,
bool is_user_input = false) {
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(tok); last_n_tokens.push_back(tok);
last_output.append(llama_token_to_str(ctx, tok)); if (!is_user_input) {
if (last_output.size() > longest_antiprompt) { last_output.append(llama_token_to_str(ctx, tok));
last_output.erase(0, last_output.size() - longest_antiprompt); if (last_output.size() > longest_antiprompt) {
last_output.erase(0, last_output.size() - longest_antiprompt);
}
} }
DEVLOG("remember_token(%`'s, %d) -> %`'s\n",
llama_token_to_str(ctx, tok), is_user_input,
last_output.c_str());
} }
static bool has_antiprompt(std::string::size_type *out_index = nullptr, static bool has_antiprompt(std::string::size_type *out_index = nullptr,
@@ -145,6 +176,8 @@ static bool has_antiprompt(std::string::size_type *out_index = nullptr,
if (index != std::string::npos) { if (index != std::string::npos) {
if (out_index) *out_index = index; if (out_index) *out_index = index;
if (out_antiprompt) *out_antiprompt = antiprompt; if (out_antiprompt) *out_antiprompt = antiprompt;
DEVLOG("found antiprompt %`'s at index %d of %`'s\n",
antiprompt.c_str(), (int)index, last_output.c_str());
return true; return true;
} }
} }
@@ -406,12 +439,12 @@ int main(int argc, char ** argv) {
remember_init(); remember_init();
bool is_antiprompt = false; is_antiprompt = false;
bool input_noecho = params.verbose <= 0; input_noecho = params.verbose <= 0;
int n_past = 0; n_past = 0;
int n_remain = params.n_predict; n_remain = params.n_predict;
int n_consumed = 0; n_consumed = 0;
// instantly reload prompt if it's cached // instantly reload prompt if it's cached
int fd = open(params.prompt_path.c_str(), O_RDONLY); int fd = open(params.prompt_path.c_str(), O_RDONLY);
@@ -532,6 +565,7 @@ int main(int argc, char ** argv) {
// perform evaluation // perform evaluation
if (embd.size() > 0) { if (embd.size() > 0) {
DEVLOG("perform evaluation embd.size()=%d\n", (int)embd.size());
if (n_past + (int) embd.size() > n_ctx) { if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep; n_past = n_keep;
embd.insert(embd.begin(), embd.insert(embd.begin(),
@@ -544,6 +578,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch; n_eval = params.n_batch;
} }
is_stalled = n_eval > 1; is_stalled = n_eval > 1;
DEVLOG("llama_eval(n_evel=%d, n_past=%d)\n", n_eval, n_past);
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
console_set_color(con_st, CONSOLE_COLOR_DEFAULT); console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
@@ -645,6 +680,8 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token // out of user input, sample next token
DEVLOG("out of user input, sample next token w/ embd_inp.size()=%d n_consumed=%d\n",
(int)embd_inp.size(), n_consumed);
const float temp = params.temp; const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
@@ -738,10 +775,12 @@ int main(int argc, char ** argv) {
--n_remain; --n_remain;
} else { } else {
DEVLOG("some user input remains from prompt or interaction w/ embd_inp.size()=%d n_consumed=%d\n",
(int)embd_inp.size(), n_consumed);
// some user input remains from prompt or interaction, forward it to processing // some user input remains from prompt or interaction, forward it to processing
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
remember_token(embd_inp[n_consumed++]); remember_token(embd_inp[n_consumed++], true);
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
break; break;
} }
@@ -791,13 +830,19 @@ int main(int argc, char ** argv) {
fflush(stdout); fflush(stdout);
} }
} }
if (is_antiprompt && !params.interactive) { if (is_antiprompt) {
if (!got_newline) { if (!params.interactive) {
printf("\n"); if (!got_newline) {
printf("\n");
}
break;
} }
break; // scrub antiprompt so to detect it must be typed again
last_output.erase(0, ap_index + ap_text.size());
DEVLOG("scrubbed antiprompt -> %`'s\n", last_output.c_str());
} }
if (prompt_status == kPromptCompleted) { if (prompt_status == kPromptCompleted) {
DEVLOG("avoid reading line before last token loads\n");
continue; // avoid reading line before last token loads continue; // avoid reading line before last token loads
} }