Fix llama.com interactive mode regressions

Justine Tunney 2023-05-13 00:09:38 -07:00
parent fd34ef732d
commit 4a8a81eb9f
GPG key ID: BE714B4575D6E328
3 changed files with 67 additions and 22 deletions


@@ -24,15 +24,11 @@
 #include "libc/limits.h"
 #include "libc/macros.internal.h"
-#ifdef DescribeIovec
-#undef DescribeIovec
-#endif
 #define N 300
 #define append(...) o += ksnprintf(buf + o, N - o, __VA_ARGS__)
-const char *DescribeIovec(char buf[N], ssize_t rc, const struct iovec *iov,
+const char *(DescribeIovec)(char buf[N], ssize_t rc, const struct iovec *iov,
                             int iovlen) {
   const char *d;
   int i, j, o = 0;

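The parenthesized name in the new definition is what makes the #undef
guard unnecessary: the preprocessor only expands a function-like macro
when its name is immediately followed by "(", so writing
"(DescribeIovec)(...)" defines the real function even when a macro of
the same name is in scope. A minimal standalone sketch of the idiom,
using a hypothetical describe() that is not part of this commit:

    #include <cstdio>

    // A function-like macro that shadows the function below.
    #define describe(x) ("macro: " #x)

    // Parenthesizing the name suppresses macro expansion, so the real
    // function can still be defined and called.
    const char *(describe)(int x) {
      static char buf[32];
      snprintf(buf, sizeof(buf), "function: %d", x);
      return buf;
    }

    int main() {
      puts(describe(1));    // macro wins: prints "macro: 1"
      puts((describe)(1));  // function wins: prints "function: 1"
    }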

@@ -338,7 +338,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         fprintf(stderr, "%s: No prompt specified\n", __func__);
         fprintf(stderr, "%s: Loading CompanionAI\n", __func__);
     }
-    append_file_to_prompt("/zip/companionai.txt", params);
+    if (fileexists("third_party/ggml/companionai.txt")) {
+        append_file_to_prompt("third_party/ggml/companionai.txt", params);
+    } else {
+        append_file_to_prompt("/zip/companionai.txt", params);
+    }
     const char *user;
     user = getenv("USER");
     if (!user || !*user) {

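The lookup order matters for development: a checked-out copy of
companionai.txt now overrides the one baked into the llama.com zip
archive, so prompt edits take effect without rebuilding. A rough sketch
of the two helpers, under the assumption that they behave like stat()
plus a whole-file append (the real fileexists() and
append_file_to_prompt() are defined elsewhere in this repo):

    #include <sys/stat.h>
    #include <fstream>
    #include <sstream>
    #include <string>

    // assumed behavior: true if the path names an existing file
    static bool fileexists_sketch(const char *path) {
        struct stat st;
        return stat(path, &st) == 0;
    }

    // assumed behavior: slurp the file and append it to the prompt text
    static void append_file_to_prompt_sketch(const char *path,
                                             std::string &prompt) {
        std::ifstream f(path);
        std::ostringstream ss;
        ss << f.rdbuf();
        prompt += ss.str();
    }

    int main() {
        std::string prompt;
        const char *path = fileexists_sketch("third_party/ggml/companionai.txt")
                               ? "third_party/ggml/companionai.txt"
                               : "/zip/companionai.txt";
        append_file_to_prompt_sketch(path, prompt);
    }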

@@ -30,11 +30,12 @@
 #include "libc/calls/calls.h"
 #include "libc/calls/struct/sigaction.h"
 #include "libc/calls/struct/stat.h"
 #include "libc/fmt/fmt.h"
 #include "libc/intrin/bits.h"
 #include "libc/intrin/kprintf.h"
 #include "libc/log/log.h"
 #include "libc/macros.internal.h"
 #include "libc/nexgen32e/x86feature.h"
 #include "libc/runtime/runtime.h"
 #include "libc/stdio/stdio.h"
 #include "libc/sysv/consts/map.h"
 #include "libc/sysv/consts/msync.h"
@@ -61,6 +62,12 @@ static gpt_params params;
 static llama_context * ctx;
 static console_state con_st;
+static int n_past;
+static int n_remain;
+static int n_consumed;
+static bool input_noecho;
+static bool is_antiprompt;
+
 ////////////////////////////////////////////////////////////////////////////////
 
 static std::atomic<bool> is_stalled;
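
Hoisting this loop state from main() to file scope (note the atomic
is_stalled already living there) lets code outside main's lexical
scope observe it. A small sketch of why that matters, as an assumption
about the motivation rather than code from this commit: state that a
signal handler must read cannot live in main()'s locals, because the
handler has no way to reach them.

    #include <csignal>

    static volatile sig_atomic_t g_interrupted;

    static void on_sigint(int) {
        g_interrupted = 1;  // async-signal-safe: plain flag write only
    }

    int main() {
        std::signal(SIGINT, on_sigint);
        while (!g_interrupted) {
            // ... sample and print tokens until Ctrl-C ...
        }
    }
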
@@ -93,6 +100,24 @@ static int CompareTime(struct timespec a, struct timespec b) {
   return cmp;
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// ux explanatory logging for llama.com developers
+
+#if 1
+#define DEVLOG(...) (void)0
+#else
+#define DEVLOG(...) if (g_devlog) fprintf(g_devlog, __VA_ARGS__)
+static FILE *g_devlog;
+__attribute__((__constructor__)) static void init(void) {
+  char path[PATH_MAX];
+  static char linebuf[4096];
+  snprintf(path, sizeof(path), "/tmp/llama-%s.log", getenv("USER"));
+  if ((g_devlog = fopen(path, "wa"))) {
+    setvbuf(g_devlog, linebuf, _IOLBF, sizeof(linebuf));
+  }
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////
 
 enum jtlp_status {
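
As committed, the toggle ships disabled: DEVLOG() compiles to a no-op
because the #if 1 branch wins. Flipping it to #if 0 routes every
DEVLOG() call through a line-buffered FILE opened by a constructor at
program start, so a developer can watch the interactive loop's
decisions as they happen. For example (assuming the toggle has been
flipped and a USER environment variable is set):

    DEVLOG("consumed %d of %d prompt tokens\n",
           n_consumed, (int)embd_inp.size());
    // then, in another terminal:
    //   tail -f /tmp/llama-$USER.log
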
@@ -129,14 +154,20 @@ static void remember_init() {
   longest_antiprompt += llama_longest_token(ctx) * 2;
 }
 
-static void remember_token(llama_token tok) {
+static void remember_token(llama_token tok,
+                           bool is_user_input = false) {
   last_n_tokens.erase(last_n_tokens.begin());
   last_n_tokens.push_back(tok);
+  if (!is_user_input) {
     last_output.append(llama_token_to_str(ctx, tok));
     if (last_output.size() > longest_antiprompt) {
       last_output.erase(0, last_output.size() - longest_antiprompt);
     }
+  }
+  DEVLOG("remember_token(%`'s, %d) -> %`'s\n",
+         llama_token_to_str(ctx, tok), is_user_input,
+         last_output.c_str());
 }
 
 static bool has_antiprompt(std::string::size_type *out_index = nullptr,
                            std::string *out_antiprompt = nullptr) {
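
The new flag is the heart of the interactive-mode fix: tokens the user
typed still enter last_n_tokens (so the repetition penalty sees them)
but stay out of last_output, which exists solely to detect reverse
prompts in model output. Without the distinction, the user typing the
reverse prompt verbatim would look like the model emitting it. A toy
demonstration of the bounded last_output bookkeeping, with made-up
sizes:

    #include <cassert>
    #include <string>

    int main() {
      std::string last_output;
      const size_t limit = 16;  // stand-in for longest_antiprompt
      for (int i = 0; i < 100; ++i) {
        last_output += "tok ";
        // same trim as remember_token(): keep only the newest bytes,
        // enough to contain the longest possible antiprompt
        if (last_output.size() > limit) {
          last_output.erase(0, last_output.size() - limit);
        }
      }
      assert(last_output.size() <= limit);
    }
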
@@ -145,6 +176,8 @@ static bool has_antiprompt(std::string::size_type *out_index = nullptr,
     if (index != std::string::npos) {
       if (out_index) *out_index = index;
       if (out_antiprompt) *out_antiprompt = antiprompt;
+      DEVLOG("found antiprompt %`'s at index %d of %`'s\n",
+             antiprompt.c_str(), (int)index, last_output.c_str());
       return true;
     }
   }
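
The search itself is a plain substring scan over the retained output.
A standalone sketch of the shape of has_antiprompt(), with the
surrounding containers stubbed in (assumptions, not this file's code):

    #include <string>
    #include <vector>

    static std::string last_output;
    static std::vector<std::string> antiprompts = {"User:"};

    static bool has_antiprompt_sketch(std::string::size_type *out_index,
                                      std::string *out_antiprompt) {
      for (const std::string &ap : antiprompts) {
        std::string::size_type index = last_output.find(ap);
        if (index != std::string::npos) {
          if (out_index) *out_index = index;
          if (out_antiprompt) *out_antiprompt = ap;
          return true;
        }
      }
      return false;
    }

    int main() {
      last_output = "Hello.\nUser:";
      return has_antiprompt_sketch(nullptr, nullptr) ? 0 : 1;
    }
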
@@ -406,12 +439,12 @@ int main(int argc, char ** argv) {
     remember_init();
 
-    bool is_antiprompt = false;
-    bool input_noecho = params.verbose <= 0;
+    is_antiprompt = false;
+    input_noecho = params.verbose <= 0;
 
-    int n_past = 0;
-    int n_remain = params.n_predict;
-    int n_consumed = 0;
+    n_past = 0;
+    n_remain = params.n_predict;
+    n_consumed = 0;
 
     // instantly reload prompt if it's cached
     int fd = open(params.prompt_path.c_str(), O_RDONLY);
@@ -532,6 +565,7 @@ int main(int argc, char ** argv) {
         // perform evaluation
         if (embd.size() > 0) {
+            DEVLOG("perform evaluation embd.size()=%d\n", (int)embd.size());
             if (n_past + (int) embd.size() > n_ctx) {
                 n_past = n_keep;
                 embd.insert(embd.begin(),
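
The branch guarded here implements context swapping: when the tokens
pending evaluation would overflow the model's context window, n_past
rewinds to the protected n_keep prefix and a recent slice of history is
re-fed ahead of the pending tokens. A toy sketch with made-up sizes
(the real code re-inserts from the last_n_tokens history):

    #include <vector>

    int main() {
      const int n_ctx = 8, n_keep = 2;
      int n_past = 7;
      std::vector<int> embd = {101, 102};          // pending tokens
      std::vector<int> last = {1, 2, 3, 4, 5, 6};  // recent history

      if (n_past + (int)embd.size() > n_ctx) {
        const int n_left = n_past - n_keep;        // 5 evictable slots
        n_past = n_keep;
        // replay roughly half the evicted window before the new tokens
        embd.insert(embd.begin(), last.end() - n_left / 2, last.end());
      }
      return (int)embd.size();  // 4: two replayed + two pending
    }
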
@@ -544,6 +578,7 @@ int main(int argc, char ** argv) {
                 n_eval = params.n_batch;
             }
             is_stalled = n_eval > 1;
+            DEVLOG("llama_eval(n_eval=%d, n_past=%d)\n", n_eval, n_past);
             if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
@@ -645,6 +680,8 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
+            DEVLOG("out of user input, sample next token w/ embd_inp.size()=%d n_consumed=%d\n",
+                   (int)embd_inp.size(), n_consumed);
             const float temp = params.temp;
             const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
             const float top_p = params.top_p;
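
These three parameters drive the usual temperature / top-k / top-p
sampling chain. A compact, self-contained illustration of temperature
plus top-k (top-p pruning omitted for brevity; the real sampler lives
in llama.cpp, and every name below is made up):

    #include <algorithm>
    #include <cmath>
    #include <random>
    #include <vector>

    // Pick a token id: keep the top_k highest logits, then softmax at
    // the given temperature and draw from the resulting distribution.
    static int sample_token(const std::vector<float> &logits, float temp,
                            int top_k, std::mt19937 &rng) {
      std::vector<int> idx(logits.size());
      for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int)i;
      if (top_k < (int)idx.size()) {
        std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                          [&](int a, int b) { return logits[a] > logits[b]; });
        idx.resize(top_k);
      }
      double mx = -1e30;
      for (int i : idx) mx = std::max(mx, (double)logits[i]);
      std::vector<double> p;
      for (int i : idx) p.push_back(std::exp((logits[i] - mx) / temp));
      std::discrete_distribution<int> pick(p.begin(), p.end());
      return idx[pick(rng)];
    }

    int main() {
      std::mt19937 rng(0);
      std::vector<float> logits = {2.0f, 1.0f, 0.1f, -1.0f};
      return sample_token(logits, 0.8f, 3, rng);
    }
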
@@ -738,10 +775,12 @@ int main(int argc, char ** argv) {
             --n_remain;
         } else {
+            DEVLOG("some user input remains from prompt or interaction w/ embd_inp.size()=%d n_consumed=%d\n",
+                   (int)embd_inp.size(), n_consumed);
             // some user input remains from prompt or interaction, forward it to processing
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                remember_token(embd_inp[n_consumed++]);
+                remember_token(embd_inp[n_consumed++], true);
                 if ((int) embd.size() >= params.n_batch) {
                     break;
                 }
@@ -791,13 +830,19 @@ int main(int argc, char ** argv) {
                 fflush(stdout);
             }
         }
 
-        if (is_antiprompt && !params.interactive) {
+        if (is_antiprompt) {
+            if (!params.interactive) {
                 if (!got_newline) {
                     printf("\n");
                 }
                 break;
             }
+            // scrub antiprompt so it must be typed again to be detected
+            last_output.erase(0, ap_index + ap_text.size());
+            DEVLOG("scrubbed antiprompt -> %`'s\n", last_output.c_str());
+        }
 
         if (prompt_status == kPromptCompleted) {
+            DEVLOG("avoid reading line before last token loads\n");
             continue;  // avoid reading line before last token loads
         }
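
The scrub step completes the fix: once a reverse prompt has handed
control to the user, the matched text is erased from last_output, so
control only returns if the antiprompt is produced again. A runnable
toy of the detect-then-scrub cycle (names assumed):

    #include <iostream>
    #include <string>

    int main() {
      std::string last_output = "Hello there.\nUser:";
      const std::string ap = "User:";

      std::string::size_type i = last_output.find(ap);
      std::cout << (i != std::string::npos) << '\n';  // 1: detected

      last_output.erase(0, i + ap.size());            // scrub through match
      std::cout << (last_output.find(ap) != std::string::npos)
                << '\n';                              // 0: must recur
    }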