mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-02-07 06:53:33 +00:00
Introduce prompt caching so prompts load instantly
This change also introduces an ephemeral status line, shown in non-verbose mode, that displays load progress as a percentage while slow operations are running.
This commit is contained in:
parent
bf6459e324
commit
b31ba86ace
7 changed files with 333 additions and 103 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,6 +1,7 @@
|
|||
# -*- conf -*-
|
||||
|
||||
/o
|
||||
/.prompt.jtlp
|
||||
|
||||
# TODO: Find some way to have Python write to o/
|
||||
__pycache__
|
||||
|
|
2
third_party/ggml/README.cosmo
vendored
2
third_party/ggml/README.cosmo
vendored
|
@ -16,7 +16,9 @@ ORIGIN
|
|||
|
||||
LOCAL CHANGES
|
||||
|
||||
- Make it possible for loaded prompts to be cached to disk
|
||||
- Introduce -v and --verbose flags
|
||||
- Reduce batch size from 512 to 32
|
||||
- Don't print stats / diagnostics unless -v is passed
|
||||
- Reduce --top_p default from 0.95 to 0.70
|
||||
- Change --reverse-prompt to no longer imply --interactive
|
||||
|
|
4
third_party/ggml/common.cc
vendored
4
third_party/ggml/common.cc
vendored
|
@ -1,5 +1,5 @@
|
|||
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
|
||||
│vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
|
||||
│vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
╚──────────────────────────────────────────────────────────────────────────────╝
|
||||
│ │
|
||||
│ llama.cpp │
|
||||
|
|
5
third_party/ggml/common.h
vendored
5
third_party/ggml/common.h
vendored
|
@ -23,7 +23,7 @@ struct gpt_params {
|
|||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_batch = 32; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
|
||||
// sampling parameters
|
||||
|
@ -34,6 +34,7 @@ struct gpt_params {
|
|||
|
||||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
std::string prompt_path = ".prompt.jtlp";
|
||||
std::string input_prefix = ""; // string to prefix user inputs with
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
|
||||
|
@ -42,7 +43,7 @@ struct gpt_params {
|
|||
|
||||
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool use_color = isatty(1) == 1; // use color to distinguish generations and inputs
|
||||
bool interactive = false; // interactive mode
|
||||
|
||||
bool embedding = false; // get only sentence embedding
|
||||
|
|
1
third_party/ggml/ggml.mk
vendored
1
third_party/ggml/ggml.mk
vendored
|
@ -72,6 +72,7 @@ THIRD_PARTY_GGML_LLAMA_DIRECTDEPS = \
|
|||
LIBC_NEXGEN32E \
|
||||
LIBC_RUNTIME \
|
||||
LIBC_STDIO \
|
||||
LIBC_LOG \
|
||||
LIBC_STR \
|
||||
LIBC_STUBS \
|
||||
LIBC_SYSV \
|
||||
|
|
65
third_party/ggml/llama.cc
vendored
65
third_party/ggml/llama.cc
vendored
|
@ -1,5 +1,5 @@
|
|||
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
|
||||
│vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
|
||||
│vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
╚──────────────────────────────────────────────────────────────────────────────╝
|
||||
│ │
|
||||
│ llama.cpp │
|
||||
|
@ -25,6 +25,30 @@
|
|||
│ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. │
|
||||
│ │
|
||||
╚─────────────────────────────────────────────────────────────────────────────*/
|
||||
#include "third_party/ggml/llama.h"
|
||||
#include "libc/intrin/bits.h"
|
||||
#include "third_party/ggml/ggml.h"
|
||||
#include "third_party/ggml/llama_util.h"
|
||||
#include "third_party/libcxx/algorithm"
|
||||
#include "third_party/libcxx/array"
|
||||
#include "third_party/libcxx/atomic"
|
||||
#include "third_party/libcxx/cassert"
|
||||
#include "third_party/libcxx/cinttypes"
|
||||
#include "third_party/libcxx/climits"
|
||||
#include "third_party/libcxx/cstdint"
|
||||
#include "third_party/libcxx/cstdio"
|
||||
#include "third_party/libcxx/cstring"
|
||||
#include "third_party/libcxx/ctime"
|
||||
#include "third_party/libcxx/fstream"
|
||||
#include "third_party/libcxx/initializer_list"
|
||||
#include "third_party/libcxx/map"
|
||||
#include "third_party/libcxx/memory"
|
||||
#include "third_party/libcxx/mutex"
|
||||
#include "third_party/libcxx/queue"
|
||||
#include "third_party/libcxx/random"
|
||||
#include "third_party/libcxx/sstream"
|
||||
#include "third_party/libcxx/thread"
|
||||
#include "third_party/libcxx/unordered_map"
|
||||
|
||||
asm(".ident\t\"\\n\\n\
|
||||
llama.cpp (MIT License)\\n\
|
||||
|
@ -32,46 +56,9 @@ Copyright (c) 2023 Georgi Gerganov\"");
|
|||
asm(".include \"libc/disclaimer.inc\"");
|
||||
// clang-format off
|
||||
|
||||
// Defines fileno on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#include "third_party/libcxx/cstdint"
|
||||
#include "third_party/libcxx/cstdio"
|
||||
#endif
|
||||
|
||||
#include "third_party/ggml/llama_util.h"
|
||||
#include "third_party/ggml/llama.h"
|
||||
|
||||
#include "third_party/ggml/ggml.h"
|
||||
|
||||
#include "third_party/libcxx/array"
|
||||
#include "third_party/libcxx/ctime"
|
||||
#include "third_party/libcxx/cinttypes"
|
||||
#include "third_party/libcxx/fstream"
|
||||
#include "third_party/libcxx/random"
|
||||
#include "third_party/libcxx/map"
|
||||
#include "third_party/libcxx/unordered_map"
|
||||
#include "third_party/libcxx/queue"
|
||||
#include "third_party/libcxx/cassert"
|
||||
#include "third_party/libcxx/cstring"
|
||||
#include "third_party/libcxx/climits"
|
||||
#include "third_party/libcxx/memory"
|
||||
#include "third_party/libcxx/algorithm"
|
||||
#include "third_party/libcxx/initializer_list"
|
||||
#include "third_party/libcxx/thread"
|
||||
#include "third_party/libcxx/atomic"
|
||||
#include "third_party/libcxx/mutex"
|
||||
#include "third_party/libcxx/sstream"
|
||||
|
||||
#define LLAMA_USE_SCRATCH
|
||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||
|
||||
#define READ32BE(s) \
|
||||
((uint32_t)((const uint8_t *)(s))[0] << 030 | \
|
||||
(uint32_t)((const uint8_t *)(s))[1] << 020 | \
|
||||
(uint32_t)((const uint8_t *)(s))[2] << 010 | \
|
||||
(uint32_t)((const uint8_t *)(s))[3] << 000)
|
||||
|
||||
// available llama models
|
||||
enum e_model {
|
||||
MODEL_UNKNOWN,
|
||||
|
|
358
third_party/ggml/main.cc
vendored
358
third_party/ggml/main.cc
vendored
|
@ -1,5 +1,5 @@
|
|||
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-│
|
||||
│vi: set net ft=c ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
|
||||
│vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
||||
╚──────────────────────────────────────────────────────────────────────────────╝
|
||||
│ │
|
||||
│ llama.cpp │
|
||||
|
@ -25,6 +25,23 @@
|
|||
│ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. │
|
||||
│ │
|
||||
╚─────────────────────────────────────────────────────────────────────────────*/
|
||||
#include "libc/assert.h"
|
||||
#include "libc/calls/struct/sigaction.h"
|
||||
#include "libc/calls/struct/stat.h"
|
||||
#include "libc/intrin/bits.h"
|
||||
#include "libc/log/log.h"
|
||||
#include "libc/nexgen32e/x86feature.h"
|
||||
#include "libc/sysv/consts/map.h"
|
||||
#include "libc/sysv/consts/msync.h"
|
||||
#include "libc/sysv/consts/o.h"
|
||||
#include "libc/sysv/consts/prot.h"
|
||||
#include "libc/sysv/consts/sig.h"
|
||||
#include "third_party/ggml/common.h"
|
||||
#include "third_party/ggml/llama.h"
|
||||
#include "third_party/ggml/llama_util.h"
|
||||
#include "third_party/libcxx/iostream"
|
||||
#include "third_party/libcxx/string"
|
||||
#include "third_party/libcxx/vector"
|
||||
|
||||
asm(".ident\t\"\\n\\n\
|
||||
llama.cpp (MIT License)\\n\
|
||||
|
@ -32,62 +49,13 @@ Copyright (c) 2023 Georgi Gerganov\"");
|
|||
asm(".include \"libc/disclaimer.inc\"");
|
||||
// clang-format off
|
||||
|
||||
// Defines sigaction on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#endif
|
||||
|
||||
#include "third_party/ggml/common.h"
|
||||
#include "libc/nexgen32e/x86feature.h"
|
||||
#include "third_party/ggml/llama.h"
|
||||
|
||||
#include "third_party/libcxx/cassert"
|
||||
#include "third_party/libcxx/cinttypes"
|
||||
#include "third_party/libcxx/cmath"
|
||||
#include "third_party/libcxx/cstdio"
|
||||
#include "third_party/libcxx/cstring"
|
||||
#include "third_party/libcxx/ctime"
|
||||
#include "third_party/libcxx/fstream"
|
||||
#include "third_party/libcxx/iostream"
|
||||
#include "third_party/libcxx/string"
|
||||
#include "third_party/libcxx/vector"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include "libc/calls/calls.h"
|
||||
#include "libc/calls/sigtimedwait.h"
|
||||
#include "libc/calls/struct/sigaction.h"
|
||||
#include "libc/calls/struct/siginfo.h"
|
||||
#include "libc/sysv/consts/sa.h"
|
||||
#include "libc/sysv/consts/sicode.h"
|
||||
#include "libc/sysv/consts/ss.h"
|
||||
#include "libc/calls/calls.h"
|
||||
#include "libc/calls/weirdtypes.h"
|
||||
#include "libc/runtime/pathconf.h"
|
||||
#include "libc/runtime/runtime.h"
|
||||
#include "libc/runtime/sysconf.h"
|
||||
#include "libc/sysv/consts/f.h"
|
||||
#include "libc/sysv/consts/fileno.h"
|
||||
#include "libc/sysv/consts/o.h"
|
||||
#include "libc/sysv/consts/ok.h"
|
||||
#include "libc/time/time.h"
|
||||
#include "third_party/getopt/getopt.h"
|
||||
#include "third_party/musl/crypt.h"
|
||||
#include "third_party/musl/lockf.h"
|
||||
#elif defined (_WIN32)
|
||||
#include "libc/calls/calls.h"
|
||||
#include "libc/calls/sigtimedwait.h"
|
||||
#include "libc/calls/struct/sigaction.h"
|
||||
#include "libc/calls/struct/siginfo.h"
|
||||
#include "libc/sysv/consts/sa.h"
|
||||
#include "libc/sysv/consts/sicode.h"
|
||||
#include "libc/sysv/consts/ss.h"
|
||||
#endif
|
||||
|
||||
static console_state con_st;
|
||||
static llama_context ** g_ctx;
|
||||
|
||||
static bool is_interacting = false;
|
||||
|
||||
#define EPHEMERAL(fmt) "\r\e[K\033[1;35m" fmt " \033[0m"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
void sigint_handler(int signo) {
|
||||
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
||||
|
@ -102,6 +70,14 @@ void sigint_handler(int signo) {
|
|||
}
|
||||
#endif
|
||||
|
||||
// Three-way comparison of two timestamps.
//
// The seconds fields are compared first; the nanoseconds fields are
// only consulted to break a tie when the seconds are equal.
//
// @param a  left-hand timestamp
// @param b  right-hand timestamp
// @return   -1 if `a` is earlier than `b`, 1 if `a` is later, 0 if equal
static int CompareTime(struct timespec a, struct timespec b) {
    if (a.tv_sec != b.tv_sec) {
        return a.tv_sec < b.tv_sec ? -1 : 1;
    }
    if (a.tv_nsec != b.tv_nsec) {
        return a.tv_nsec < b.tv_nsec ? -1 : 1;
    }
    return 0;
}
|
||||
|
||||
static int on_missing_feature(const char *name) {
|
||||
fprintf(stderr, "error: we require %s support in your microprocessor.\n", name);
|
||||
return 1;
|
||||
|
@ -109,6 +85,9 @@ static int on_missing_feature(const char *name) {
|
|||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
ShowCrashReports();
|
||||
|
||||
params.model = "models/llama-7B/ggml-model.bin";
|
||||
|
||||
if (!X86_HAVE(AVX2)) return on_missing_feature("avx2");
|
||||
|
@ -167,6 +146,7 @@ int main(int argc, char ** argv) {
|
|||
//bool is_prime(int n) {)";
|
||||
|
||||
llama_context * ctx;
|
||||
struct stat model_stat;
|
||||
g_ctx = &ctx;
|
||||
|
||||
// load the model
|
||||
|
@ -182,8 +162,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
ctx = llama_init_from_file(params.model.c_str(), lparams, params.verbose);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
|
||||
if (ctx == NULL || stat(params.model.c_str(), &model_stat)) {
|
||||
fprintf(stderr, "%s: failed to load model: %s\n",
|
||||
params.model.c_str(), strerror(errno));
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
@ -327,6 +308,28 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = params.interactive_first;
|
||||
}
|
||||
|
||||
const uint32_t kJtlpMagic = READ32LE("jtlp");
|
||||
const uint32_t kJtlpVersion = 0;
|
||||
|
||||
struct jtlp_header {
|
||||
uint8_t magic[4];
|
||||
uint8_t version[4];
|
||||
uint8_t state_size[8];
|
||||
uint8_t model_dev[8];
|
||||
uint8_t model_ino[8];
|
||||
uint8_t model_mtim_sec[8];
|
||||
uint8_t model_mtim_nsec[8];
|
||||
uint8_t prompt_size[8];
|
||||
};
|
||||
|
||||
enum jtlp_status {
|
||||
kPromptPending,
|
||||
kPromptCompleted,
|
||||
kPromptFinished
|
||||
};
|
||||
|
||||
enum jtlp_status prompt_status = kPromptPending;
|
||||
|
||||
bool is_antiprompt = false;
|
||||
bool input_noecho = !params.verbose;
|
||||
|
||||
|
@ -334,13 +337,146 @@ int main(int argc, char ** argv) {
|
|||
int n_remain = params.n_predict;
|
||||
int n_consumed = 0;
|
||||
|
||||
// the first thing we will do is to output the prompt, so set color accordingly
|
||||
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
// instantly reload prompt if it's cached
|
||||
int fd = open(params.prompt_path.c_str(), O_RDONLY);
|
||||
if (fd != -1) {
|
||||
size_t state_size;
|
||||
size_t prompt_size;
|
||||
struct timespec mtim;
|
||||
struct jtlp_header *header;
|
||||
off_t rc = lseek(fd, 0, SEEK_END);
|
||||
LLAMA_ASSERT(rc != -1);
|
||||
void *map = MAP_FAILED;
|
||||
size_t file_size = rc;
|
||||
if (file_size < sizeof(header)) {
|
||||
fprintf(stderr, "%s: prompt file too small\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
map = mmap(0, file_size, PROT_READ, MAP_SHARED, fd, 0);
|
||||
if (map == MAP_FAILED) {
|
||||
fprintf(stderr, "%s: mmap failed: %s\n",
|
||||
params.prompt_path.c_str(), strerror(errno));
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
header = (struct jtlp_header *)map;
|
||||
// check file format magic
|
||||
if (READ32LE(header->magic) != kJtlpMagic) {
|
||||
fprintf(stderr, "%s: prompt file has wrong magic\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check file format version
|
||||
if (READ32LE(header->version) > kJtlpVersion) {
|
||||
fprintf(stderr, "%s: prompt has future file format version\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check expected state size
|
||||
state_size = llama_get_state_size(ctx);
|
||||
if (READ64LE(header->state_size) != state_size) {
|
||||
if (params.verbose) {
|
||||
fprintf(stderr, "%s: prompt has stale data state size\n",
|
||||
params.prompt_path.c_str());
|
||||
}
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check model device id
|
||||
if (READ64LE(header->model_dev) != model_stat.st_dev) {
|
||||
fprintf(stderr, "%s: prompt is for different model (dev)\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check model inode id
|
||||
if (READ64LE(header->model_ino) != model_stat.st_ino) {
|
||||
fprintf(stderr, "%s: prompt is for different model (ino)\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check model modified timestamp
|
||||
mtim.tv_sec = READ64LE(header->model_mtim_sec);
|
||||
mtim.tv_nsec = READ64LE(header->model_mtim_nsec);
|
||||
if (CompareTime(model_stat.st_mtim, mtim) > 0) {
|
||||
if (params.verbose) {
|
||||
fprintf(stderr, "%s: model file timestamp changed; will reload and regenerate prompt\n",
|
||||
params.prompt_path.c_str());
|
||||
}
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check prompt file size
|
||||
prompt_size = READ64LE(header->prompt_size);
|
||||
if (sizeof(struct jtlp_header) + prompt_size + state_size > file_size) {
|
||||
fprintf(stderr, "%s: prompt file size unexpected\n",
|
||||
params.prompt_path.c_str());
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// check prompt textus
|
||||
if (prompt_size != params.prompt.size() ||
|
||||
memcmp(header + 1, params.prompt.c_str(), prompt_size) != 0) {
|
||||
if (params.verbose) {
|
||||
fprintf(stderr, "%s: prompt text changed; will reload and regenerate\n",
|
||||
params.prompt_path.c_str());
|
||||
}
|
||||
goto CantReloadPrompt;
|
||||
}
|
||||
// read the transformer state
|
||||
llama_set_state_data(ctx, (uint8_t *)(header + 1) + prompt_size);
|
||||
// we're finished loading the prompt file
|
||||
if (params.verbose) {
|
||||
fprintf(stderr, "%s: %s: reloaded previously saved prompt\n",
|
||||
__func__, params.prompt_path.c_str());
|
||||
}
|
||||
// now setup the business logic
|
||||
llama_set_rng_seed(ctx, params.seed);
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(embd_inp[n_consumed++]);
|
||||
}
|
||||
n_past = n_consumed;
|
||||
prompt_status = kPromptFinished;
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
fflush(stdout);
|
||||
std::string last_output;
|
||||
for (auto id : last_n_tokens) {
|
||||
last_output += llama_token_to_str(ctx, id);
|
||||
}
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
if (last_output.find(antiprompt.c_str(),
|
||||
last_output.length() - antiprompt.length(),
|
||||
antiprompt.length()) != std::string::npos) {
|
||||
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
printf("%s", antiprompt.c_str());
|
||||
fflush(stdout);
|
||||
break;
|
||||
}
|
||||
}
|
||||
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
||||
}
|
||||
CantReloadPrompt:
|
||||
if (map != MAP_FAILED) {
|
||||
munmap(map, file_size);
|
||||
}
|
||||
close(fd);
|
||||
}
|
||||
|
||||
if (prompt_status == kPromptPending && params.verbose) {
|
||||
// the first thing we will do is to output the prompt, so set color accordingly
|
||||
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
}
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
if (prompt_status == kPromptPending &&
|
||||
!params.verbose && con_st.use_color) {
|
||||
fprintf(stderr, EPHEMERAL("loading model..."));
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
while (n_remain != 0 || params.interactive) {
|
||||
// predict
|
||||
|
||||
// performance inference evaluation of scheduled tokens
|
||||
// this loads prompt tokens and it also does prediction
|
||||
if (embd.size() > 0) {
|
||||
// infinite text generation via context swapping
|
||||
// if we run out of context:
|
||||
|
@ -375,11 +511,102 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
n_past += n_eval;
|
||||
if (prompt_status == kPromptPending &&
|
||||
!params.verbose && con_st.use_color && embd_inp.size()) {
|
||||
fprintf(stderr, EPHEMERAL("loading prompt %d%% ..."),
|
||||
(int)(n_consumed / (double)embd_inp.size() * 100));
|
||||
fflush(stderr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
|
||||
// save prompt to disk atomically as soon as it's finished loading
|
||||
bool was_completed = prompt_status == kPromptCompleted;
|
||||
if (was_completed && !params.prompt_path.empty()) {
|
||||
int fd = -1;
|
||||
int close_rc;
|
||||
uint8_t buf[8];
|
||||
size_t file_size;
|
||||
size_t state_size;
|
||||
std::string tmppath;
|
||||
void *map = MAP_FAILED;
|
||||
struct jtlp_header header;
|
||||
if (!params.verbose && con_st.use_color) {
|
||||
fprintf(stderr, EPHEMERAL("caching prompt..."));
|
||||
fflush(stderr);
|
||||
}
|
||||
state_size = llama_get_state_size(ctx);
|
||||
WRITE32LE(header.magic, kJtlpMagic);
|
||||
WRITE32LE(header.version, kJtlpVersion);
|
||||
WRITE64LE(header.state_size, state_size);
|
||||
WRITE64LE(header.model_dev, model_stat.st_dev);
|
||||
WRITE64LE(header.model_ino, model_stat.st_ino);
|
||||
WRITE64LE(header.model_mtim_sec, model_stat.st_mtim.tv_sec);
|
||||
WRITE64LE(header.model_mtim_nsec, model_stat.st_mtim.tv_nsec);
|
||||
WRITE64LE(header.prompt_size, params.prompt.size());
|
||||
file_size = sizeof(header) + params.prompt.size() + state_size;
|
||||
tmppath.append(params.prompt_path);
|
||||
tmppath.append(".XXXXXX");
|
||||
fd = mkstemp(&tmppath[0]);
|
||||
if (fd == -1) {
|
||||
fprintf(stderr, "%s: mkstemp failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
goto CouldNotSavePrompt;
|
||||
}
|
||||
if (ftruncate(fd, file_size)) {
|
||||
fprintf(stderr, "%s: ftruncate failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
goto CouldNotSavePrompt;
|
||||
}
|
||||
map = mmap(0, file_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
|
||||
if (map == MAP_FAILED) {
|
||||
fprintf(stderr, "%s: mmap failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
goto CouldNotSavePrompt;
|
||||
}
|
||||
llama_copy_state_data(ctx, (uint8_t *)map + sizeof(header) + params.prompt.size());
|
||||
memcpy((uint8_t *)map + sizeof(header), params.prompt.c_str(), params.prompt.size());
|
||||
memcpy(map, &header, sizeof(header));
|
||||
if (msync(map, file_size, MS_ASYNC) && params.verbose) {
|
||||
fprintf(stderr, "%s: msync failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
}
|
||||
if (munmap(map, file_size) && params.verbose) {
|
||||
fprintf(stderr, "%s: munmap failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
}
|
||||
map = MAP_FAILED;
|
||||
close_rc = close(fd);
|
||||
fd = -1;
|
||||
if (close_rc) {
|
||||
fprintf(stderr, "%s: close failed: %s\n",
|
||||
tmppath.c_str(), strerror(errno));
|
||||
goto CouldNotSavePrompt;
|
||||
}
|
||||
if (rename(tmppath.c_str(), params.prompt_path.c_str())) {
|
||||
fprintf(stderr, "%s -> %s: rename failed: %s\n",
|
||||
tmppath.c_str(), params.prompt_path.c_str(), strerror(errno));
|
||||
goto CouldNotSavePrompt;
|
||||
}
|
||||
tmppath.clear();
|
||||
CouldNotSavePrompt:
|
||||
if (map != MAP_FAILED) munmap(map, file_size);
|
||||
if (fd != -1) close(fd);
|
||||
if (!tmppath.empty()) unlink(tmppath.c_str());
|
||||
}
|
||||
if (was_completed) {
|
||||
if (!params.verbose && con_st.use_color) {
|
||||
fprintf(stderr, EPHEMERAL(""));
|
||||
fflush(stderr);
|
||||
}
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
prompt_status = kPromptFinished;
|
||||
}
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// out of user input, sample next token
|
||||
const int32_t top_k = params.top_k;
|
||||
|
@ -422,17 +649,23 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// decrement remaining sampling budget
|
||||
--n_remain;
|
||||
|
||||
} else {
|
||||
// some user input remains from prompt or interaction, forward it to processing
|
||||
while ((int) embd_inp.size() > n_consumed) {
|
||||
embd.push_back(embd_inp[n_consumed]);
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(embd_inp[n_consumed]);
|
||||
++n_consumed;
|
||||
last_n_tokens.push_back(embd_inp[n_consumed++]);
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// we've nearly finished loading the prompt
|
||||
if (prompt_status == kPromptPending &&
|
||||
(int) embd_inp.size() <= n_consumed) {
|
||||
prompt_status = kPromptCompleted;
|
||||
}
|
||||
}
|
||||
|
||||
// checks for reverse prompt
|
||||
|
@ -476,6 +709,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
if (prompt_status == kPromptCompleted) {
|
||||
continue; // avoid reading line before last token loads
|
||||
}
|
||||
|
||||
// reset color to default if there is no pending user input
|
||||
if (params.verbose && !input_noecho && (int)embd_inp.size() == n_consumed) {
|
||||
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
||||
|
@ -521,6 +758,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
win32_utf8_encode(wline, line);
|
||||
#else
|
||||
fflush(stdout);
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// input stream is bad or EOF received
|
||||
return 0;
|
||||
|
|
Loading…
Reference in a new issue