Use Companion AI in llama.com by default

This commit is contained in:
Justine Tunney 2023-04-29 00:48:14 -07:00
parent d9e27203d4
commit 3dac9f8999
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
8 changed files with 310 additions and 193 deletions

View file

@ -2,7 +2,8 @@
vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
llama.cpp
llama.com
Copyright (c) 2023 Justine Alexandra Roberts Tunney
Copyright (c) 2023 Georgi Gerganov
Permission is hereby granted, free of charge, to any person obtaining
@ -26,11 +27,13 @@
*/
#include "libc/assert.h"
#include "libc/calls/calls.h"
#include "libc/calls/struct/sigaction.h"
#include "libc/calls/struct/stat.h"
#include "libc/intrin/bits.h"
#include "libc/log/log.h"
#include "libc/nexgen32e/x86feature.h"
#include "libc/stdio/stdio.h"
#include "libc/sysv/consts/map.h"
#include "libc/sysv/consts/msync.h"
#include "libc/sysv/consts/o.h"
@ -57,7 +60,6 @@ static bool is_interacting = false;
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
@ -65,6 +67,7 @@ void sigint_handler(int signo) {
if (!is_interacting) {
is_interacting=true;
} else {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
if (g_verbose) {
llama_print_timings(*g_ctx);
}
@ -72,7 +75,6 @@ void sigint_handler(int signo) {
}
}
}
#endif
static int CompareTime(struct timespec a, struct timespec b) {
int cmp;
@ -83,7 +85,9 @@ static int CompareTime(struct timespec a, struct timespec b) {
}
// Reports on stderr that the CPU feature `name` was not detected and
// prints hints about which processor generations usually provide it.
//
// @param name  lowercase cpuid feature name, e.g. "avx2"
// @return      always 1, so callers can `return on_missing_feature(...)`
//              straight out of main() as a failing exit status
static int on_missing_feature(const char *name) {
    // captured once; every diagnostic below is prefixed with this function's name
    const char *const self = __func__;
    const int failure_status = 1;

    fprintf(stderr, "error: we require %s support in your microprocessor.\n", name);
    fprintf(stderr, "%s: error: cpuid %s not detected\n", self, name);

    // vendor hints so the user can tell whether their hardware is simply too old
    fprintf(stderr, "%s: amd microprocessors made after 2017 usually work\n", self);
    fprintf(stderr, "%s: intel microprocessors made after 2013 usually work\n", self);

    return failure_status;
}
@ -91,15 +95,19 @@ int main(int argc, char ** argv) {
gpt_params params;
ShowCrashReports();
setvbuf(stderr, NULL, _IONBF, 0);
params.model = "models/llama-7B/ggml-model.bin";
if (!X86_HAVE(AVX2)) return on_missing_feature("avx2");
if (!X86_HAVE(AVX)) return on_missing_feature("avx");
if (!X86_HAVE(FMA)) return on_missing_feature("fma");
if (!X86_HAVE(F16C)) return on_missing_feature("f16c");
if (!X86_HAVE(SSE3)) return on_missing_feature("sse3");
if (!X86_HAVE(F16C)) {
fprintf(stderr, "%s: warning: cpuid f16c not detected; inference might crash\n", __func__);
}
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}
@ -108,10 +116,6 @@ int main(int argc, char ** argv) {
// (note for later: this is a slightly awkward choice)
con_st.use_color = params.use_color;
#if defined (_WIN32)
win32_console_init(params.use_color);
#endif
g_verbose = params.verbose;
if (params.perplexity) {
@ -228,8 +232,20 @@ int main(int argc, char ** argv) {
}
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
params.n_keep = (int)embd_inp.size();
int n_keep = params.n_keep;
if (n_keep < 0 || n_keep > (int)embd_inp.size() || params.instruct) {
n_keep = (int)embd_inp.size();
}
if (!n_keep && !params.n_keep_str.empty()) {
auto pivot = ::llama_tokenize(ctx, params.n_keep_str, false);
auto pos = std::search(embd_inp.begin(), embd_inp.end(),
pivot.begin(), pivot.end());
if (pos == embd_inp.end()) {
fprintf(stderr, "%s: error: --n_keep %`'s substring not found within prompt\n",
__func__, params.n_keep_str.c_str());
return 1;
}
n_keep = (pos - embd_inp.begin()) + (pivot.end() - pivot.begin());
}
// prefix & suffix for instruct mode
@ -255,28 +271,27 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%6d %6d -> %`'s\n", i, embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
}
if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
}
fprintf(stderr, "'\n");
fprintf(stderr, "%s: first part of prompt: \"", __func__);
for (int i = 0; i < n_keep; i++) {
fprintf(stderr, "%'s", llama_token_to_str(ctx, embd_inp[i]));
}
fprintf(stderr, "\"\n");
fprintf(stderr, "%s: second part of prompt: \"", __func__);
for (int i = n_keep; i < embd_inp.size(); i++) {
fprintf(stderr, "%'s", llama_token_to_str(ctx, embd_inp[i]));
}
fprintf(stderr, "\"\n");
fprintf(stderr, "\n");
}
if (params.interactive) {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
signal(SIGINT, sigint_handler);
#endif
if (params.verbose) {
fprintf(stderr, "%s: interactive mode on.\n", __func__);
@ -292,11 +307,12 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
if (params.verbose) {
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n",
n_ctx, params.n_batch, params.n_predict, n_keep);
fprintf(stderr, "\n\n");
}
@ -306,9 +322,7 @@ int main(int argc, char ** argv) {
if (params.verbose && params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
" - Press Ctrl+C to interject at any time.\n"
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_first;
@ -442,15 +456,12 @@ int main(int argc, char ** argv) {
prompt_status = kPromptFinished;
if (params.interactive) {
is_interacting = true;
fflush(stdout);
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(),
last_output.length() - antiprompt.length(),
antiprompt.length()) != std::string::npos) {
auto toks = ::llama_tokenize(ctx, antiprompt, false);
if (std::equal(last_n_tokens.end() - toks.size(),
last_n_tokens.end(),
toks.begin(),
toks.end())) {
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
@ -475,38 +486,19 @@ int main(int argc, char ** argv) {
if (prompt_status == kPromptPending &&
!params.verbose && con_st.use_color) {
fprintf(stderr, EPHEMERAL("loading model..."));
fflush(stderr);
fprintf(stderr, EPHEMERAL("loading weights..."));
}
while (n_remain != 0 || params.interactive) {
// performance inference evaluation of scheduled tokens
// this loads prompt tokens and it also does prediction
// perform evaluation
if (embd.size() > 0) {
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() > n_ctx) {
const int n_left = n_past - params.n_keep;
n_past = params.n_keep;
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
n_past = n_keep;
embd.insert(embd.begin(),
last_n_tokens.end() - (n_past - n_keep) / 2 - embd.size(),
last_n_tokens.end() - embd.size());
}
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@ -514,6 +506,7 @@ int main(int argc, char ** argv) {
}
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 1;
}
n_past += n_eval;
@ -521,13 +514,11 @@ int main(int argc, char ** argv) {
!params.verbose && con_st.use_color && embd_inp.size()) {
fprintf(stderr, EPHEMERAL("loading prompt %d%% ..."),
(int)(n_consumed / (double)embd_inp.size() * 100));
fflush(stderr);
}
}
embd.clear();
}
embd.clear();
// save prompt to disk atomically as soon as it's finished loading
bool was_completed = prompt_status == kPromptCompleted;
if (was_completed && !params.prompt_path.empty()) {
@ -541,7 +532,6 @@ int main(int argc, char ** argv) {
struct jtlp_header header;
if (!params.verbose && con_st.use_color) {
fprintf(stderr, EPHEMERAL("caching prompt..."));
fflush(stderr);
}
state_size = llama_get_state_size(ctx);
WRITE32LE(header.magic, kJtlpMagic);
@ -605,12 +595,30 @@ int main(int argc, char ** argv) {
if (was_completed) {
if (!params.verbose && con_st.use_color) {
fprintf(stderr, EPHEMERAL(""));
fflush(stderr);
}
if (params.interactive) {
is_interacting = true;
}
prompt_status = kPromptFinished;
if (params.interactive) {
is_interacting = true;
fflush(stdout);
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(),
last_output.length() - antiprompt.length(),
antiprompt.length()) != std::string::npos) {
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
printf("%s", antiprompt.c_str());
fflush(stdout);
break;
}
}
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
}
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
@ -735,14 +743,10 @@ int main(int argc, char ** argv) {
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
#if defined (_WIN32)
// Windows: must reactivate sigint handler after each signal
signal(SIGINT, sigint_handler);
#endif
if (params.instruct) {
printf("\n> ");
}
@ -753,23 +757,25 @@ int main(int argc, char ** argv) {
printf("%s", buffer.c_str());
}
// display a "waiting for input" indicator, just in case
// the model doesn't halt on the antiprompt.
if (con_st.use_color) {
fprintf(stdout, "?\b");
fflush(stdout);
}
std::string line;
bool another_line = true;
do {
#if defined(_WIN32)
std::wstring wline;
if (!std::getline(std::wcin, wline)) {
// input stream is bad or EOF received
return 0;
}
win32_utf8_encode(wline, line);
#else
fflush(stdout);
if (!std::getline(std::cin, line)) {
// input stream is bad or EOF received
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
if (g_verbose) {
llama_print_timings(*g_ctx);
}
return 0;
}
#endif
if (line.empty() || line.back() != '\\') {
another_line = false;
} else {
@ -808,6 +814,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
is_interacting = false;
}
assert(!is_interacting);
}
// end of text token
@ -827,10 +834,6 @@ int main(int argc, char ** argv) {
}
}
#if defined (_WIN32)
signal(SIGINT, SIG_DFL);
#endif
if (params.verbose) {
llama_print_timings(ctx);
}