llama : more compact state saving and reloading
Commit 98914c0ed0, parent 1fd1918bdc
2 changed files with 139 additions and 58 deletions

llama.cpp (173 lines changed)
@@ -2102,8 +2102,8 @@ struct llama_context {
     float * logits = nullptr;
 
     int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
-    size_t  output_size = 0; // capacity (of tokens positions) for the output buffer
-    int32_t n_outputs   = 0; // number of actually-used outputs in the previous batch
+    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current or previous batch
 
     bool logits_all = false;
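The two fields above describe a compaction scheme: `output_ids` maps a token's position in the batch to its row in the logits/embd buffers, and `n_outputs` counts how many rows are actually in use. A standalone illustration (not llama.cpp code, values made up) of that mapping:

```cpp
// Standalone sketch of the output_ids bookkeeping described above.
#include <cstdio>
#include <vector>

int main() {
    // a batch of 6 tokens where only positions 2 and 5 requested logits
    std::vector<signed char> logits_flags = {0, 0, 1, 0, 0, 1};

    // output_ids: batch position -> row in the output buffer, -1 if no output
    std::vector<int> output_ids(logits_flags.size(), -1);
    int n_outputs = 0;
    for (size_t i = 0; i < logits_flags.size(); ++i) {
        if (logits_flags[i]) {
            output_ids[i] = n_outputs++;
        }
    }

    // the logits buffer then only needs n_outputs rows instead of n_tokens rows
    std::printf("n_outputs = %d\n", n_outputs);        // 2
    std::printf("output_ids[5] = %d\n", output_ids[5]); // 1 -> second row of the buffer
    return 0;
}
```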
@@ -9192,15 +9192,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     GGML_ASSERT(0 <= n_outputs);
 
-    const int32_t n_outputs_max = std::max((uint32_t) n_outputs, lctx.cparams.n_seq_max);
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
 
-    const auto n_batch = lctx.cparams.n_batch;
-    const auto n_vocab = lctx.model.hparams.n_vocab;
-    const auto n_embd  = lctx.model.hparams.n_embd;
+    const int32_t n_outputs_max = std::max((uint32_t) n_outputs, cparams.n_seq_max);
+
+    const auto n_batch = cparams.n_batch;
+    const auto n_vocab = hparams.n_vocab;
+    const auto n_embd  = hparams.n_embd;
     const int64_t capacity = lctx.output_size;
 
-    const bool has_logits = lctx.cparams.causal_attn;
-    const bool has_embd   = lctx.cparams.embeddings;
+    const bool has_logits = cparams.causal_attn;
+    const bool has_embd   = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     if (!lctx.output_ids) {
         // never resized afterwards
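Per the new `has_embd` condition, the per-token embeddings buffer is reserved only when embeddings are enabled and either the model is non-causal or pooling is `LLAMA_POOLING_TYPE_NONE`. A hedged caller-side sketch of a context configured so that this buffer gets used (assuming the llama.h context params of this revision; the helper name is illustrative):

```cpp
// Sketch: configuring a context so per-token embeddings land in the output buffer.
#include "llama.h"

llama_context * make_embedding_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // request embeddings output
    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;  // per-token embeddings, no pooling
    cparams.n_batch      = 512;                      // upper bound used when sizing output_ids
    return llama_new_context_with_model(model, cparams);
}
```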
@@ -9211,29 +9214,32 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     }
     // alloc only when more than the current logits capacity is required
     if (capacity < n_outputs_max) {
-        lctx.output_size = n_outputs_max;
-        lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-        lctx.embd_size   = has_embd ? n_embd*n_outputs_max : 0;
-
-        const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
-
         if (lctx.buf_output) {
 #ifndef NDEBUG
             const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
             fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0);
 #endif
             ggml_backend_buffer_free(lctx.buf_output);
             lctx.buf_output = nullptr;
             lctx.logits = nullptr;
             lctx.embd = nullptr;
         }
-
-        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
-        if (lctx.buf_output == nullptr) {
-            throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
-        }
-
-        float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
-
-        lctx.logits = has_logits ? output_base : nullptr;
-        lctx.embd   = has_embd ? output_base + lctx.logits_size : nullptr;
+        {
+            lctx.output_size = n_outputs_max;
+            lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+            lctx.embd_size   = has_embd ? n_embd*n_outputs_max : 0;
+
+            const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
+
+            lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
+            if (lctx.buf_output == nullptr) {
+                throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
+            }
+
+            float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+
+            lctx.logits = has_logits ? output_base : nullptr;
+            lctx.embd   = has_embd ? output_base + lctx.logits_size : nullptr;
+        }
     }
     // set all ids as invalid (assume two's complement negative numbers)
     memset(lctx.output_ids, -1, n_batch*sizeof(int32_t));
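The buffer is sized for `n_outputs_max` rows rather than one row per batch token, so callers decide how many rows exist by setting `llama_batch.logits`. A minimal decode sketch using the public batch API (hedged: helper name and error handling are illustrative), where only the last prompt token requests an output row, so the buffer ends up holding a single row of `n_vocab` floats:

```cpp
// Sketch: only the last prompt token requests an output row.
#include "llama.h"
#include <vector>

int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);

    for (size_t i = 0; i < tokens.size(); ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = (llama_pos) i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == tokens.size() - 1); // only the last token produces logits
    }
    batch.n_tokens = (int32_t) tokens.size();

    const int ret = llama_decode(ctx, batch); // output buffer now holds one row of n_vocab floats
    llama_batch_free(batch);
    return ret;
}
```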
@@ -14038,27 +14044,32 @@ void llama_kv_cache_update(struct llama_context * ctx) {
 
 // Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
+    const auto & cparams = ctx->cparams;
+    const auto & hparams = ctx->model.hparams;
+
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
     const size_t s_rng_size       = sizeof(size_t);
     const size_t s_rng            = LLAMA_MAX_RNG_STATE;
+    const size_t s_n_outputs      = sizeof(size_t);
+    // assume worst case for outputs although only currently set ones are serialized
+    const size_t s_output_pos     = ctx->cparams.n_batch * sizeof(int32_t);
     const size_t s_logits_size    = sizeof(size_t);
     // assume worst case for logits although only currently set ones are serialized
-    const size_t s_logits         = ctx->logits_size * sizeof(float);
+    const size_t s_logits         = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
     const size_t s_embedding_size = sizeof(size_t);
-    const size_t s_embedding      = ctx->embd_size * sizeof(float);
+    const size_t s_embedding      = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
     const size_t s_kv_buf_size    = sizeof(size_t);
     const size_t s_kv_head        = sizeof(uint32_t);
     const size_t s_kv_size        = sizeof(uint32_t);
     const size_t s_kv_used        = sizeof(uint32_t);
     const size_t s_kv             = ctx->kv_self.total_size();
-    // TODO: assume the max is more than 1 seq_id per KV cell
-    const size_t s_kv_cell        = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cell        = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
     const size_t s_kv_cells       = ctx->kv_self.size * s_kv_cell;
 
     const size_t s_total = (
         + s_rng_size
         + s_rng
+        + s_n_outputs
+        + s_output_pos
         + s_logits_size
         + s_logits
         + s_embedding_size
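Since llama_get_state_size now returns a worst-case (maximum) size while llama_copy_state_data writes a compacted state, the buffer a caller allocates can be larger than the bytes actually produced. A hedged save-to-file sketch that keeps only the bytes actually copied (function name and error handling are illustrative, using the llama.h API of this revision):

```cpp
// Sketch: save the context state to a file using the maximum-size estimate.
#include "llama.h"
#include <cstdint>
#include <cstdio>
#include <vector>

bool save_state(llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));           // worst-case size
    const size_t written = llama_copy_state_data(ctx, buf.data()); // actual (compact) size

    FILE * f = std::fopen(path, "wb");
    if (!f) {
        return false;
    }
    const bool ok = std::fwrite(buf.data(), 1, written, f) == written;
    std::fclose(f);
    return ok;
}
```

The matching restore sketch follows the llama_set_state_data hunk further down.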
@@ -14142,25 +14153,60 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         data_ctx->write(rng_str.data(), rng_size);
     }
 
-    // copy logits
+    // copy outputs
     {
-        const size_t logits_size = ctx->logits_size;
+        size_t n_outputs = ctx->n_outputs;
 
-        data_ctx->write(&logits_size, sizeof(logits_size));
+        // copy output ids
+        {
+            std::vector<int32_t> output_pos;
+            const size_t    n_batch    = ctx->cparams.n_batch;
+            const int32_t * output_ids = ctx->output_ids;
 
-        if (logits_size) {
-            data_ctx->write(ctx->logits, logits_size * sizeof(float));
-        }
-    }
+            output_pos.resize(n_outputs);
 
-    // copy embeddings
-    {
-        const size_t embeddings_size = ctx->embd_size;
+            // build a more compact representation of the output ids
+            for (size_t i = 0; i < n_batch; ++i) {
+                // map an output id to a position in the batch
+                int32_t pos = output_ids[i];
+                if (pos >= 0) {
+                    if ((size_t) pos >= output_pos.size()) {
+                        // TODO: maybe fail here instead
+                        LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
+                        n_outputs = pos + 1;
+                        output_pos.resize(n_outputs);
+                    }
+                    output_pos[pos] = i;
+                }
+            }
 
-        data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+            data_ctx->write(&n_outputs, sizeof(n_outputs));
 
-        if (embeddings_size) {
-            data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+            if (n_outputs) {
+                data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
+            }
+        }
+
+        // copy logits
+        {
+            const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
+
+            data_ctx->write(&logits_size, sizeof(logits_size));
+
+            if (logits_size) {
+                data_ctx->write(ctx->logits, logits_size * sizeof(float));
+            }
+        }
+
+        // copy embeddings
+        {
+            const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+
+            data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+
+            if (embeddings_size) {
+                data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+            }
         }
     }
 
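The serializer inverts `output_ids` (batch position to row) into `output_pos` (row to batch position), so only `n_outputs` entries are written instead of `n_batch`, and only the used rows of the logits/embeddings buffers follow. A standalone illustration of that inversion with made-up values (not llama.cpp code):

```cpp
// Standalone illustration of the output_ids -> output_pos inversion used above.
#include <cstdio>
#include <vector>

int main() {
    const int n_batch = 8;
    // batch position -> output row (-1 means the token produced no output)
    const int output_ids[n_batch] = {-1, -1, 0, -1, 1, -1, -1, 2};
    const int n_outputs = 3;

    // output row -> batch position: this is what gets serialized
    std::vector<int> output_pos(n_outputs, -1);
    for (int i = 0; i < n_batch; ++i) {
        const int pos = output_ids[i];
        if (pos >= 0) {
            output_pos[pos] = i;
        }
    }

    for (int row = 0; row < n_outputs; ++row) {
        std::printf("row %d came from batch position %d\n", row, output_pos[row]);
    }
    return 0;
}
```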
@@ -14257,6 +14303,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         GGML_ASSERT(!rng_ss.fail());
     }
 
+    // set output ids
+    {
+        size_t n_outputs;
+        std::vector<int32_t> output_pos;
+
+        memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+
+        llama_output_reserve(*ctx, n_outputs);
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
+            inp += n_outputs * sizeof(int32_t);
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                ctx->output_ids[id] = i;
+            }
+        }
+    }
+
     // set logits
     {
         size_t logits_size;
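Restoring goes through the same path: llama_set_state_data reads the serialized `output_pos`, calls llama_output_reserve so the buffers exist, and rebuilds `output_ids` from it. A hedged loading counterpart to the earlier save sketch (function name and file handling are illustrative):

```cpp
// Sketch: restore a previously saved state blob.
#include "llama.h"
#include <cstdint>
#include <cstdio>
#include <vector>

bool load_state(llama_context * ctx, const char * path) {
    FILE * f = std::fopen(path, "rb");
    if (!f) {
        return false;
    }
    std::fseek(f, 0, SEEK_END);
    const long size = std::ftell(f);
    std::fseek(f, 0, SEEK_SET);
    if (size <= 0) {
        std::fclose(f);
        return false;
    }

    std::vector<uint8_t> buf((size_t) size);
    const bool read_ok = std::fread(buf.data(), 1, buf.size(), f) == buf.size();
    std::fclose(f);
    if (!read_ok) {
        return false;
    }

    // returns the number of bytes consumed from the buffer
    return llama_set_state_data(ctx, buf.data()) <= buf.size();
}
```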
@@ -14277,7 +14345,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embd_size == embeddings_size);
+        GGML_ASSERT(ctx->embd_size >= embeddings_size);
 
         if (embeddings_size) {
             memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
@@ -14562,7 +14630,6 @@ void llama_synchronize(struct llama_context * ctx) {
 }
 
 float * llama_get_logits(struct llama_context * ctx) {
-    // TODO: assert that really all logits are in the output
     llama_synchronize(ctx);
 
     return ctx->logits;
@@ -14570,12 +14637,17 @@ float * llama_get_logits(struct llama_context * ctx) {
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
-    GGML_ASSERT(0 <= j);
 
     llama_synchronize(ctx);
 
-    // FIXME: check for nullptr
-    return ctx->logits + j*ctx->model.hparams.n_vocab;
+    if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    }
+    LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
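With the FIXME resolved, llama_get_logits_ith now logs and returns NULL (in release builds) instead of reading out of bounds when position `i` never produced logits. A hedged caller-side sketch (helper name is illustrative; `i_last` is assumed to be a batch position that had `llama_batch.logits[i] != 0`, such as the last prompt token in the earlier decode sketch):

```cpp
// Sketch: defensive use of llama_get_logits_ith after the change above.
#include "llama.h"
#include <cstdio>

float greedy_peek(llama_context * ctx, const llama_model * model, int32_t i_last) {
    const float * logits = llama_get_logits_ith(ctx, i_last);
    if (logits == nullptr) {
        std::fprintf(stderr, "no logits were produced for batch position %d\n", i_last);
        return 0.0f;
    }

    // example: find the highest-scoring token id in that row
    const int n_vocab = llama_n_vocab(model);
    int best = 0;
    for (int t = 1; t < n_vocab; ++t) {
        if (logits[t] > logits[best]) {
            best = t;
        }
    }
    std::printf("argmax token id: %d\n", best);
    return logits[best];
}
```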
@@ -14586,12 +14658,17 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
-    GGML_ASSERT(0 <= j);
 
     llama_synchronize(ctx);
 
-    // FIXME: check for nullptr
-    return ctx->embd + j*ctx->model.hparams.n_embd;
+    if (ctx->embd && 0 < j && j < ctx->n_outputs) {
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    }
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
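The embeddings accessor gets the same treatment, though note that the bound check here is written `0 < j` where the logits variant uses `0 <= j`; either way, callers should treat a NULL return as "no embedding stored for this position". A hedged usage sketch (helper name illustrative; assumes embeddings were enabled on the context as in the earlier setup sketch):

```cpp
// Sketch: reading a per-token embedding defensively.
#include "llama.h"
#include <cmath>
#include <cstdio>

void print_embedding_norm(llama_context * ctx, const llama_model * model, int32_t i) {
    const float * embd = llama_get_embeddings_ith(ctx, i);
    if (embd == nullptr) {
        std::fprintf(stderr, "no embedding stored for batch position %d\n", i);
        return;
    }

    const int n_embd = llama_n_embd(model);
    double sum = 0.0;
    for (int k = 0; k < n_embd; ++k) {
        sum += (double) embd[k] * embd[k];
    }
    std::printf("||embd|| = %f\n", std::sqrt(sum));
}
```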
llama.h (24 lines changed)
@@ -39,7 +39,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 4
+#define LLAMA_SESSION_VERSION 5
 
 #ifdef __cplusplus
 extern "C" {
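Bumping LLAMA_SESSION_VERSION to 5 means session files written with the old state layout are rejected on load rather than misread. A hedged sketch of round-tripping a session file through the public helpers, which perform the magic/version check internally (function name and path are placeholders):

```cpp
// Sketch: saving and restoring a session file; the loader checks
// LLAMA_SESSION_MAGIC / LLAMA_SESSION_VERSION and fails on mismatch.
#include "llama.h"
#include <cstdio>
#include <vector>

bool roundtrip_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
    const char * path = "state.llama_session"; // placeholder path

    if (!llama_save_session_file(ctx, path, tokens.data(), tokens.size())) {
        return false;
    }

    std::vector<llama_token> loaded(tokens.size());
    size_t n_loaded = 0;
    if (!llama_load_session_file(ctx, path, loaded.data(), loaded.size(), &n_loaded)) {
        std::fprintf(stderr, "session file rejected (wrong magic/version?)\n");
        return false;
    }
    return n_loaded == tokens.size();
}
```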
@@ -674,25 +674,29 @@ extern "C" {
     LLAMA_API void llama_synchronize(struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_decode()
-    // WARNING: the following layout is only valid when the batch outputs logits for all tokens
-    // The logits for the last token are stored in the last row
-    // Logits for which llama_batch.logits[i] == 0 are undefined
-    // Rows: n_tokens provided with llama_batch
+    // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they have in the batch.
+    // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. Equivalent to:
     // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+    // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get all output token embeddings
-    // WARNING: only use when all outputs are requested
-    // shape: [n_tokens*n_embd] (1-dimensional)
+    // Get all output token embeddings.
+    // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
+    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they have in the batch.
+    // shape: [n_outputs*n_embd]
+    // Otherwise, returns NULL.
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
-    // Get the embeddings for the ith token
-    // llama_get_embeddings(ctx) + i*n_embd
+    // Get the embeddings for the ith token. Equivalent to:
+    // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
     // shape: [n_embd] (1-dimensional)
+    // returns NULL for invalid ids.
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 
     // Get the embeddings for a sequence id
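The updated header comments describe the compacted layout: llama_get_logits returns rows only for the tokens whose `llama_batch.logits` flag was set, in batch order. A hedged sketch walking those rows directly off the base pointer (helper name illustrative; `n_rows` is assumed to equal the number of flagged tokens in the last batch):

```cpp
// Sketch: walking the compacted logits rows returned by llama_get_logits().
#include "llama.h"
#include <cstdio>

void dump_first_logit_per_row(llama_context * ctx, const llama_model * model, int n_rows) {
    const float * base    = llama_get_logits(ctx);
    const int     n_vocab = llama_n_vocab(model);

    for (int row = 0; row < n_rows; ++row) {
        const float * row_logits = base + (size_t) row * n_vocab;
        std::printf("row %d: logits[0] = %f\n", row, row_logits[0]);
    }
}
```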