Introduce prompt caching so prompts load instantly

This change also introduces an ephemeral status line in non-verbose mode
that displays a load percentage while slow operations are in progress.
Justine Tunney 2023-04-28 16:15:26 -07:00
parent bf6459e324
commit b31ba86ace
7 changed files with 333 additions and 103 deletions
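
For readers unfamiliar with the status-line trick mentioned above: the EPHEMERAL() macro added in this commit redraws a single terminal row by emitting a carriage return plus the erase-to-end-of-line escape before each message, so progress updates overwrite each other instead of scrolling. A minimal standalone sketch of the same idea (the progress values here are made up for illustration):

    #include <stdio.h>
    #include <unistd.h>

    // Rewrites one terminal row in place: '\r' returns the cursor to column 0
    // and "\033[K" erases the rest of the row, so each new message replaces
    // the previous one instead of scrolling the screen.
    static void show_status(const char *msg) {
        fprintf(stderr, "\r\033[K\033[1;35m%s \033[0m", msg);
        fflush(stderr);
    }

    int main(void) {
        char buf[64];
        for (int pct = 0; pct <= 100; pct += 25) {   // fake progress values
            snprintf(buf, sizeof(buf), "loading prompt %d%% ...", pct);
            show_status(buf);
            usleep(100000);
        }
        show_status("");   // printing an empty message clears the line
        return 0;
    }

In the patch this is only used when stderr is a color-capable terminal and --verbose is off.
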

.gitignore (changed file 1 of 7)

@@ -1,6 +1,7 @@
 # -*- conf -*-
 /o
+/.prompt.jtlp
 # TODO: Find some way to have Python write to o/
 __pycache__

(changed file 2 of 7)

@@ -16,7 +16,9 @@ ORIGIN
 LOCAL CHANGES
+- Make it possible for loaded prompts to be cached to disk
 - Introduce -v and --verbose flags
+- Reduce batch size from 512 to 32
 - Don't print stats / diagnostics unless -v is passed
 - Reduce --top_p default from 0.95 to 0.70
 - Change --reverse-prompt to no longer imply --interactive

(changed file 3 of 7)

@@ -1,5 +1,5 @@
-/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
-vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi
+/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
+vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
 llama.cpp

(changed file 4 of 7)

@@ -23,7 +23,7 @@ struct gpt_params {
 int32_t repeat_last_n = 64; // last n tokens to penalize
 int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
 int32_t n_ctx = 512; // context size
-int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+int32_t n_batch = 32; // batch size for prompt processing (must be >=32 to use BLAS)
 int32_t n_keep = 0; // number of tokens to keep from initial prompt
 // sampling parameters
@@ -34,6 +34,7 @@ struct gpt_params {
 std::string model = "models/lamma-7B/ggml-model.bin"; // model path
 std::string prompt = "";
+std::string prompt_path = ".prompt.jtlp";
 std::string input_prefix = ""; // string to prefix user inputs with
 std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
@@ -42,7 +43,7 @@ struct gpt_params {
 bool memory_f16 = true; // use f16 instead of f32 for memory kv
 bool random_prompt = false; // do not randomize prompt if none provided
-bool use_color = false; // use color to distinguish generations and inputs
+bool use_color = isatty(1) == 1; // use color to distinguish generations and inputs
 bool interactive = false; // interactive mode
 bool embedding = false; // get only sentence embedding

(changed file 5 of 7)

@@ -72,6 +72,7 @@ THIRD_PARTY_GGML_LLAMA_DIRECTDEPS = \
 LIBC_NEXGEN32E \
 LIBC_RUNTIME \
 LIBC_STDIO \
+LIBC_LOG \
 LIBC_STR \
 LIBC_STUBS \
 LIBC_SYSV \

(changed file 6 of 7)

@@ -1,5 +1,5 @@
-/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
-vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi
+/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
+vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
 llama.cpp
@@ -25,6 +25,30 @@
 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
+#include "third_party/ggml/llama.h"
+#include "libc/intrin/bits.h"
+#include "third_party/ggml/ggml.h"
+#include "third_party/ggml/llama_util.h"
+#include "third_party/libcxx/algorithm"
+#include "third_party/libcxx/array"
+#include "third_party/libcxx/atomic"
+#include "third_party/libcxx/cassert"
+#include "third_party/libcxx/cinttypes"
+#include "third_party/libcxx/climits"
+#include "third_party/libcxx/cstdint"
+#include "third_party/libcxx/cstdio"
+#include "third_party/libcxx/cstring"
+#include "third_party/libcxx/ctime"
+#include "third_party/libcxx/fstream"
+#include "third_party/libcxx/initializer_list"
+#include "third_party/libcxx/map"
+#include "third_party/libcxx/memory"
+#include "third_party/libcxx/mutex"
+#include "third_party/libcxx/queue"
+#include "third_party/libcxx/random"
+#include "third_party/libcxx/sstream"
+#include "third_party/libcxx/thread"
+#include "third_party/libcxx/unordered_map"
 asm(".ident\t\"\\n\\n\
 llama.cpp (MIT License)\\n\
@@ -32,46 +56,9 @@ Copyright (c) 2023 Georgi Gerganov\"");
 asm(".include \"libc/disclaimer.inc\"");
 // clang-format off
-// Defines fileno on msys:
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#include "third_party/libcxx/cstdint"
-#include "third_party/libcxx/cstdio"
-#endif
-#include "third_party/ggml/llama_util.h"
-#include "third_party/ggml/llama.h"
-#include "third_party/ggml/ggml.h"
-#include "third_party/libcxx/array"
-#include "third_party/libcxx/ctime"
-#include "third_party/libcxx/cinttypes"
-#include "third_party/libcxx/fstream"
-#include "third_party/libcxx/random"
-#include "third_party/libcxx/map"
-#include "third_party/libcxx/unordered_map"
-#include "third_party/libcxx/queue"
-#include "third_party/libcxx/cassert"
-#include "third_party/libcxx/cstring"
-#include "third_party/libcxx/climits"
-#include "third_party/libcxx/memory"
-#include "third_party/libcxx/algorithm"
-#include "third_party/libcxx/initializer_list"
-#include "third_party/libcxx/thread"
-#include "third_party/libcxx/atomic"
-#include "third_party/libcxx/mutex"
-#include "third_party/libcxx/sstream"
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
-#define READ32BE(s) \
-((uint32_t)((const uint8_t *)(s))[0] << 030 | \
-(uint32_t)((const uint8_t *)(s))[1] << 020 | \
-(uint32_t)((const uint8_t *)(s))[2] << 010 | \
-(uint32_t)((const uint8_t *)(s))[3] << 000)
 // available llama models
 enum e_model {
 MODEL_UNKNOWN,

(changed file 7 of 7)

@@ -1,5 +1,5 @@
-/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
-vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi
+/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
+vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
 llama.cpp
@@ -25,6 +25,23 @@
 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
+#include "libc/assert.h"
+#include "libc/calls/struct/sigaction.h"
+#include "libc/calls/struct/stat.h"
+#include "libc/intrin/bits.h"
+#include "libc/log/log.h"
+#include "libc/nexgen32e/x86feature.h"
+#include "libc/sysv/consts/map.h"
+#include "libc/sysv/consts/msync.h"
+#include "libc/sysv/consts/o.h"
+#include "libc/sysv/consts/prot.h"
+#include "libc/sysv/consts/sig.h"
+#include "third_party/ggml/common.h"
+#include "third_party/ggml/llama.h"
+#include "third_party/ggml/llama_util.h"
+#include "third_party/libcxx/iostream"
+#include "third_party/libcxx/string"
+#include "third_party/libcxx/vector"
 asm(".ident\t\"\\n\\n\
 llama.cpp (MIT License)\\n\
@@ -32,62 +49,13 @@ Copyright (c) 2023 Georgi Gerganov\"");
 asm(".include \"libc/disclaimer.inc\"");
 // clang-format off
-// Defines sigaction on msys:
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-#include "third_party/ggml/common.h"
-#include "libc/nexgen32e/x86feature.h"
-#include "third_party/ggml/llama.h"
-#include "third_party/libcxx/cassert"
-#include "third_party/libcxx/cinttypes"
-#include "third_party/libcxx/cmath"
-#include "third_party/libcxx/cstdio"
-#include "third_party/libcxx/cstring"
-#include "third_party/libcxx/ctime"
-#include "third_party/libcxx/fstream"
-#include "third_party/libcxx/iostream"
-#include "third_party/libcxx/string"
-#include "third_party/libcxx/vector"
-#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
-#include "libc/calls/calls.h"
-#include "libc/calls/sigtimedwait.h"
-#include "libc/calls/struct/sigaction.h"
-#include "libc/calls/struct/siginfo.h"
-#include "libc/sysv/consts/sa.h"
-#include "libc/sysv/consts/sicode.h"
-#include "libc/sysv/consts/ss.h"
-#include "libc/calls/calls.h"
-#include "libc/calls/weirdtypes.h"
-#include "libc/runtime/pathconf.h"
-#include "libc/runtime/runtime.h"
-#include "libc/runtime/sysconf.h"
-#include "libc/sysv/consts/f.h"
-#include "libc/sysv/consts/fileno.h"
-#include "libc/sysv/consts/o.h"
-#include "libc/sysv/consts/ok.h"
-#include "libc/time/time.h"
-#include "third_party/getopt/getopt.h"
-#include "third_party/musl/crypt.h"
-#include "third_party/musl/lockf.h"
-#elif defined (_WIN32)
-#include "libc/calls/calls.h"
-#include "libc/calls/sigtimedwait.h"
-#include "libc/calls/struct/sigaction.h"
-#include "libc/calls/struct/siginfo.h"
-#include "libc/sysv/consts/sa.h"
-#include "libc/sysv/consts/sicode.h"
-#include "libc/sysv/consts/ss.h"
-#endif
 static console_state con_st;
 static llama_context ** g_ctx;
 static bool is_interacting = false;
+#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
 set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
@@ -102,6 +70,14 @@ void sigint_handler(int signo) {
 }
 #endif
+static int CompareTime(struct timespec a, struct timespec b) {
+int cmp;
+if (!(cmp = (a.tv_sec > b.tv_sec) - (a.tv_sec < b.tv_sec))) {
+cmp = (a.tv_nsec > b.tv_nsec) - (a.tv_nsec < b.tv_nsec);
+}
+return cmp;
+}
 static int on_missing_feature(const char *name) {
 fprintf(stderr, "error: we require %s support in your microprocessor.\n", name);
 return 1;
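
The CompareTime() helper added above uses the (a > b) - (a < b) idiom, which evaluates to -1, 0, or +1 without the overflow risk of subtracting two time_t values; seconds are compared first and nanoseconds break ties. A tiny self-contained check with made-up timestamps:

    #include <assert.h>
    #include <time.h>

    // Three-way comparison: negative if a < b, zero if equal, positive if a > b.
    static int compare_timespec(struct timespec a, struct timespec b) {
        int cmp = (a.tv_sec > b.tv_sec) - (a.tv_sec < b.tv_sec);
        if (!cmp) cmp = (a.tv_nsec > b.tv_nsec) - (a.tv_nsec < b.tv_nsec);
        return cmp;
    }

    int main(void) {
        struct timespec cached = {0}, model = {0};
        cached.tv_sec = 1682700000;                        // hypothetical mtime stored in the cache
        model.tv_sec = 1682700000; model.tv_nsec = 500000; // model file touched slightly later
        assert(compare_timespec(model, cached) > 0);       // model is newer => cache gets regenerated
        assert(compare_timespec(cached, cached) == 0);
        return 0;
    }
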
@@ -109,6 +85,9 @@ static int on_missing_feature(const char *name) {
 int main(int argc, char ** argv) {
 gpt_params params;
+ShowCrashReports();
 params.model = "models/llama-7B/ggml-model.bin";
 if (!X86_HAVE(AVX2)) return on_missing_feature("avx2");
@@ -167,6 +146,7 @@ int main(int argc, char ** argv) {
 //bool is_prime(int n) {)";
 llama_context * ctx;
+struct stat model_stat;
 g_ctx = &ctx;
 // load the model
@@ -182,8 +162,9 @@ int main(int argc, char ** argv) {
 ctx = llama_init_from_file(params.model.c_str(), lparams, params.verbose);
-if (ctx == NULL) {
-fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+if (ctx == NULL || stat(params.model.c_str(), &model_stat)) {
+fprintf(stderr, "%s: failed to load model: %s\n",
+params.model.c_str(), strerror(errno));
 return 1;
 }
 }
@@ -327,6 +308,28 @@ int main(int argc, char ** argv) {
 is_interacting = params.interactive_first;
 }
+const uint32_t kJtlpMagic = READ32LE("jtlp");
+const uint32_t kJtlpVersion = 0;
+struct jtlp_header {
+uint8_t magic[4];
+uint8_t version[4];
+uint8_t state_size[8];
+uint8_t model_dev[8];
+uint8_t model_ino[8];
+uint8_t model_mtim_sec[8];
+uint8_t model_mtim_nsec[8];
+uint8_t prompt_size[8];
+};
+enum jtlp_status {
+kPromptPending,
+kPromptCompleted,
+kPromptFinished
+};
+enum jtlp_status prompt_status = kPromptPending;
 bool is_antiprompt = false;
 bool input_noecho = !params.verbose;
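
The jtlp header above stores every field as a fixed-width little-endian byte array rather than as native integers, so the on-disk layout does not depend on the compiler's struct padding or the host's endianness; on disk the header is followed by the raw prompt text and then the serialized transformer state, which is how the reload code below finds each piece. READ32LE/READ64LE and WRITE32LE/WRITE64LE come from the libc/intrin/bits.h header included above; a portable sketch of the same round trip (the helper names here are illustrative, not part of the patch):

    #include <assert.h>
    #include <stdint.h>

    // Encode a 64-bit value as 8 little-endian bytes, least significant first.
    static void write_u64_le(uint8_t out[8], uint64_t v) {
        for (int i = 0; i < 8; ++i) out[i] = (uint8_t)(v >> (8 * i));
    }

    // Decode 8 little-endian bytes back into a 64-bit value.
    static uint64_t read_u64_le(const uint8_t in[8]) {
        uint64_t v = 0;
        for (int i = 0; i < 8; ++i) v |= (uint64_t)in[i] << (8 * i);
        return v;
    }

    int main(void) {
        uint8_t field[8];
        write_u64_le(field, 123456789);        // e.g. a state_size value
        assert(read_u64_le(field) == 123456789);
        return 0;
    }
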
@@ -334,13 +337,146 @@ int main(int argc, char ** argv) {
 int n_remain = params.n_predict;
 int n_consumed = 0;
-// the first thing we will do is to output the prompt, so set color accordingly
-set_console_color(con_st, CONSOLE_COLOR_PROMPT);
+// instantly reload prompt if it's cached
+int fd = open(params.prompt_path.c_str(), O_RDONLY);
+if (fd != -1) {
+size_t state_size;
+size_t prompt_size;
+struct timespec mtim;
+struct jtlp_header *header;
+off_t rc = lseek(fd, 0, SEEK_END);
+LLAMA_ASSERT(rc != -1);
+void *map = MAP_FAILED;
+size_t file_size = rc;
+if (file_size < sizeof(header)) {
+fprintf(stderr, "%s: prompt file too small\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+map = mmap(0, file_size, PROT_READ, MAP_SHARED, fd, 0);
+if (map == MAP_FAILED) {
+fprintf(stderr, "%s: mmap failed: %s\n",
+params.prompt_path.c_str(), strerror(errno));
+goto CantReloadPrompt;
+}
+header = (struct jtlp_header *)map;
+// check file format magic
+if (READ32LE(header->magic) != kJtlpMagic) {
+fprintf(stderr, "%s: prompt file has wrong magic\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+// check file format version
+if (READ32LE(header->version) > kJtlpVersion) {
+fprintf(stderr, "%s: prompt has future file format version\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+// check expected state size
+state_size = llama_get_state_size(ctx);
+if (READ64LE(header->state_size) != state_size) {
+if (params.verbose) {
+fprintf(stderr, "%s: prompt has stale data state size\n",
+params.prompt_path.c_str());
+}
+goto CantReloadPrompt;
+}
+// check model device id
+if (READ64LE(header->model_dev) != model_stat.st_dev) {
+fprintf(stderr, "%s: prompt is for different model (dev)\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+// check model inode id
+if (READ64LE(header->model_ino) != model_stat.st_ino) {
+fprintf(stderr, "%s: prompt is for different model (ino)\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+// check model modified timestamp
+mtim.tv_sec = READ64LE(header->model_mtim_sec);
+mtim.tv_nsec = READ64LE(header->model_mtim_nsec);
+if (CompareTime(model_stat.st_mtim, mtim) > 0) {
+if (params.verbose) {
+fprintf(stderr, "%s: model file timestamp changed; will reload and regenerate prompt\n",
+params.prompt_path.c_str());
+}
+goto CantReloadPrompt;
+}
+// check prompt file size
+prompt_size = READ64LE(header->prompt_size);
+if (sizeof(struct jtlp_header) + prompt_size + state_size > file_size) {
+fprintf(stderr, "%s: prompt file size unexpected\n",
+params.prompt_path.c_str());
+goto CantReloadPrompt;
+}
+// check prompt textus
+if (prompt_size != params.prompt.size() ||
+memcmp(header + 1, params.prompt.c_str(), prompt_size) != 0) {
+if (params.verbose) {
+fprintf(stderr, "%s: prompt text changed; will reload and regenerate\n",
+params.prompt_path.c_str());
+}
+goto CantReloadPrompt;
+}
+// read the transformer state
+llama_set_state_data(ctx, (uint8_t *)(header + 1) + prompt_size);
+// we're finished loading the prompt file
+if (params.verbose) {
+fprintf(stderr, "%s: %s: reloaded previously saved prompt\n",
+__func__, params.prompt_path.c_str());
+}
+// now setup the business logic
+llama_set_rng_seed(ctx, params.seed);
+while ((int) embd_inp.size() > n_consumed) {
+last_n_tokens.erase(last_n_tokens.begin());
+last_n_tokens.push_back(embd_inp[n_consumed++]);
+}
+n_past = n_consumed;
+prompt_status = kPromptFinished;
+if (params.interactive) {
+is_interacting = true;
+fflush(stdout);
+std::string last_output;
+for (auto id : last_n_tokens) {
+last_output += llama_token_to_str(ctx, id);
+}
+for (std::string & antiprompt : params.antiprompt) {
+if (last_output.find(antiprompt.c_str(),
+last_output.length() - antiprompt.length(),
+antiprompt.length()) != std::string::npos) {
+set_console_color(con_st, CONSOLE_COLOR_PROMPT);
+printf("%s", antiprompt.c_str());
+fflush(stdout);
+break;
+}
+}
+set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+}
+CantReloadPrompt:
+if (map != MAP_FAILED) {
+munmap(map, file_size);
+}
+close(fd);
+}
+if (prompt_status == kPromptPending && params.verbose) {
+// the first thing we will do is to output the prompt, so set color accordingly
+set_console_color(con_st, CONSOLE_COLOR_PROMPT);
+}
 std::vector<llama_token> embd;
+if (prompt_status == kPromptPending &&
+!params.verbose && con_st.use_color) {
+fprintf(stderr, EPHEMERAL("loading model..."));
+fflush(stderr);
+}
 while (n_remain != 0 || params.interactive) {
-// predict
+// performance inference evaluation of scheduled tokens
+// this loads prompt tokens and it also does prediction
 if (embd.size() > 0) {
 // infinite text generation via context swapping
 // if we run out of context:
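
To summarize the reload path added above: the cached state is trusted only if every validation passes; any failure jumps to CantReloadPrompt and the prompt is evaluated from scratch on this run (and re-cached afterwards). A condensed, illustrative restatement of those checks, relying on the jtlp_header layout and constants defined in the patch (error logging, the file-size bound, the mtime comparison, and cleanup are omitted here):

    // Illustrative only: mirrors the order of the checks in the real code.
    static bool cache_is_reusable(const jtlp_header *h, size_t state_size,
                                  const struct stat *model,
                                  const std::string &prompt) {
        if (READ32LE(h->magic) != kJtlpMagic) return false;          // not a jtlp file
        if (READ32LE(h->version) > kJtlpVersion) return false;       // written by a newer format
        if (READ64LE(h->state_size) != state_size) return false;     // context/state shape changed
        if (READ64LE(h->model_dev) != model->st_dev) return false;   // different model file (device)
        if (READ64LE(h->model_ino) != model->st_ino) return false;   // different model file (inode)
        // the real code also rejects the cache when the model's mtime is newer
        // than the timestamp recorded in the header (see CompareTime above)
        if (READ64LE(h->prompt_size) != prompt.size()) return false; // prompt text changed
        return !memcmp(h + 1, prompt.data(), prompt.size());         // prompt bytes follow the header
    }
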
@@ -375,11 +511,102 @@ int main(int argc, char ** argv) {
 return 1;
 }
 n_past += n_eval;
+if (prompt_status == kPromptPending &&
+!params.verbose && con_st.use_color && embd_inp.size()) {
+fprintf(stderr, EPHEMERAL("loading prompt %d%% ..."),
+(int)(n_consumed / (double)embd_inp.size() * 100));
+fflush(stderr);
+}
 }
 }
 embd.clear();
+// save prompt to disk atomically as soon as it's finished loading
+bool was_completed = prompt_status == kPromptCompleted;
+if (was_completed && !params.prompt_path.empty()) {
+int fd = -1;
+int close_rc;
+uint8_t buf[8];
+size_t file_size;
+size_t state_size;
+std::string tmppath;
+void *map = MAP_FAILED;
+struct jtlp_header header;
+if (!params.verbose && con_st.use_color) {
+fprintf(stderr, EPHEMERAL("caching prompt..."));
+fflush(stderr);
+}
+state_size = llama_get_state_size(ctx);
+WRITE32LE(header.magic, kJtlpMagic);
+WRITE32LE(header.version, kJtlpVersion);
+WRITE64LE(header.state_size, state_size);
+WRITE64LE(header.model_dev, model_stat.st_dev);
+WRITE64LE(header.model_ino, model_stat.st_ino);
+WRITE64LE(header.model_mtim_sec, model_stat.st_mtim.tv_sec);
+WRITE64LE(header.model_mtim_nsec, model_stat.st_mtim.tv_nsec);
+WRITE64LE(header.prompt_size, params.prompt.size());
+file_size = sizeof(header) + params.prompt.size() + state_size;
+tmppath.append(params.prompt_path);
+tmppath.append(".XXXXXX");
+fd = mkstemp(&tmppath[0]);
+if (fd == -1) {
+fprintf(stderr, "%s: mkstemp failed: %s\n",
+tmppath.c_str(), strerror(errno));
+goto CouldNotSavePrompt;
+}
+if (ftruncate(fd, file_size)) {
+fprintf(stderr, "%s: ftruncate failed: %s\n",
+tmppath.c_str(), strerror(errno));
+goto CouldNotSavePrompt;
+}
+map = mmap(0, file_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+if (map == MAP_FAILED) {
+fprintf(stderr, "%s: mmap failed: %s\n",
+tmppath.c_str(), strerror(errno));
+goto CouldNotSavePrompt;
+}
+llama_copy_state_data(ctx, (uint8_t *)map + sizeof(header) + params.prompt.size());
+memcpy((uint8_t *)map + sizeof(header), params.prompt.c_str(), params.prompt.size());
+memcpy(map, &header, sizeof(header));
+if (msync(map, file_size, MS_ASYNC) && params.verbose) {
+fprintf(stderr, "%s: msync failed: %s\n",
+tmppath.c_str(), strerror(errno));
+}
+if (munmap(map, file_size) && params.verbose) {
+fprintf(stderr, "%s: munmap failed: %s\n",
+tmppath.c_str(), strerror(errno));
+}
+map = MAP_FAILED;
+close_rc = close(fd);
+fd = -1;
+if (close_rc) {
+fprintf(stderr, "%s: close failed: %s\n",
+tmppath.c_str(), strerror(errno));
+goto CouldNotSavePrompt;
+}
+if (rename(tmppath.c_str(), params.prompt_path.c_str())) {
+fprintf(stderr, "%s -> %s: rename failed: %s\n",
+tmppath.c_str(), params.prompt_path.c_str(), strerror(errno));
+goto CouldNotSavePrompt;
+}
+tmppath.clear();
+CouldNotSavePrompt:
+if (map != MAP_FAILED) munmap(map, file_size);
+if (fd != -1) close(fd);
+if (!tmppath.empty()) unlink(tmppath.c_str());
+}
+if (was_completed) {
+if (!params.verbose && con_st.use_color) {
+fprintf(stderr, EPHEMERAL(""));
+fflush(stderr);
+}
+if (params.interactive) {
+is_interacting = true;
+}
+prompt_status = kPromptFinished;
+}
 if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 // out of user input, sample next token
 const int32_t top_k = params.top_k;
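
The save path above follows the classic write-to-temp-then-rename recipe: the state is written into a unique temporary file beside the target, and rename() publishes it atomically, so a crash mid-write can never leave a truncated .prompt.jtlp behind. A minimal sketch of the pattern on its own (file name and payload are placeholders):

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    // Writes `data` to `path` atomically: readers either see the old file
    // or the complete new one, never a partially written file.
    static int save_atomically(const char *path, const void *data, size_t size) {
        char tmp[4096];
        snprintf(tmp, sizeof(tmp), "%s.XXXXXX", path);
        int fd = mkstemp(tmp);                 // unique temp file in the same directory
        if (fd == -1) return -1;
        ssize_t written = write(fd, data, size);
        int close_rc = close(fd);
        if (written != (ssize_t)size || close_rc) {
            unlink(tmp);
            return -1;
        }
        if (rename(tmp, path)) {               // atomic publish
            unlink(tmp);
            return -1;
        }
        return 0;
    }

    int main(void) {
        const char msg[] = "hello";            // placeholder payload
        return save_atomically("/tmp/example.bin", msg, sizeof(msg));
    }

The patch itself mmap()s the temporary file and lets llama_copy_state_data() serialize straight into the mapping rather than calling write(), but the mkstemp-then-rename publish step is the same.
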
@@ -422,17 +649,23 @@ int main(int argc, char ** argv) {
 // decrement remaining sampling budget
 --n_remain;
 } else {
 // some user input remains from prompt or interaction, forward it to processing
 while ((int) embd_inp.size() > n_consumed) {
 embd.push_back(embd_inp[n_consumed]);
 last_n_tokens.erase(last_n_tokens.begin());
-last_n_tokens.push_back(embd_inp[n_consumed]);
-++n_consumed;
+last_n_tokens.push_back(embd_inp[n_consumed++]);
 if ((int) embd.size() >= params.n_batch) {
 break;
 }
 }
+// we've nearly finished loading the prompt
+if (prompt_status == kPromptPending &&
+(int) embd_inp.size() <= n_consumed) {
+prompt_status = kPromptCompleted;
+}
 }
 // checks for reverse prompt
@@ -476,6 +709,10 @@ int main(int argc, char ** argv) {
 }
 fflush(stdout);
 }
+if (prompt_status == kPromptCompleted) {
+continue; // avoid reading line before last token loads
+}
 // reset color to default if we there is no pending user input
 if (params.verbose && !input_noecho && (int)embd_inp.size() == n_consumed) {
 set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
@@ -521,6 +758,7 @@ int main(int argc, char ** argv) {
 }
 win32_utf8_encode(wline, line);
 #else
+fflush(stdout);
 if (!std::getline(std::cin, line)) {
 // input stream is bad or EOF received
 return 0;