llama : greatly reduce logits memory usage

Francis Couture-Harpin 2024-03-15 00:46:34 -04:00
parent d01b3c4c32
commit 1fd1918bdc
4 changed files with 378 additions and 78 deletions
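
In short, the change below stops reserving logits (and embedding) rows for every token of a batch; the context now allocates rows only for the tokens whose outputs are actually requested and keeps a small map from token position to row. As a minimal, standalone sketch of that idea (illustrative only, not code from this commit; the names are made up):

#include <cstdint>
#include <vector>

// Compact output buffer: token positions that request logits get a row,
// every other position maps to -1 and uses no storage at all.
struct output_buffer {
    int32_t n_vocab = 0;
    std::vector<float>   logits;     // [n_outputs][n_vocab] instead of [n_batch][n_vocab]
    std::vector<int32_t> output_ids; // token position -> row in logits, or -1

    void reserve(const std::vector<bool> & wants_logits, int32_t n_vocab_) {
        n_vocab = n_vocab_;
        output_ids.assign(wants_logits.size(), -1);
        int32_t n_outputs = 0;
        for (size_t i = 0; i < wants_logits.size(); ++i) {
            if (wants_logits[i]) { output_ids[i] = n_outputs++; }
        }
        logits.resize((size_t) n_outputs * n_vocab); // only the requested rows
    }

    // analogous to llama_get_logits_ith(): translate a token position to a row
    float * row_for_position(int32_t i) {
        const int32_t j = output_ids[i];
        return j < 0 ? nullptr : logits.data() + (size_t) j * n_vocab;
    }
};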


@ -132,7 +132,6 @@ int main(int argc, char ** argv) {
llama_context * ctx = NULL;
// load the target model
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
// load the prompts from an external file if there are any


@ -65,7 +65,6 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
params.logits_all = true;
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model

llama.cpp (450 changed lines)

@ -1737,6 +1737,7 @@ struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@ -2054,6 +2055,8 @@ struct llama_context {
ggml_backend_free(backend);
}
free(output_ids);
#ifdef GGML_USE_VULKAN
ggml_vk_free_cpu_assist();
#endif
@ -2094,19 +2097,19 @@ struct llama_context {
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
// decode output (2-dimensional array: [n_tokens][n_vocab])
size_t logits_size = 0;
float * logits = nullptr;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffer
int32_t n_outputs = 0; // number of actually-used outputs in the previous batch
#ifndef NDEBUG
// guard against access to unset logits
std::vector<bool> logits_valid;
#endif
bool logits_all = false;
// embeddings output (2-dimensional array: [n_tokens][n_embd])
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0;
size_t embd_size = 0; // capacity (of floats) for embeddings
float * embd = nullptr;
// sequence embeddings output (map of [n_embd] vectors)
@ -2124,14 +2127,15 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
struct ggml_tensor * inp_KQ_pos; // F32 [n_kv]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
// control vectors
struct llama_control_vector cvec;
@ -5562,7 +5566,8 @@ struct llm_build_context {
const float norm_rms_eps;
const int32_t n_tokens;
const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
const int32_t n_outputs;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
@ -5609,6 +5614,7 @@ struct llm_build_context {
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
@ -5753,6 +5759,13 @@ struct llm_build_context {
return lctx.inp_pos;
}
struct ggml_tensor * build_inp_out_ids() {
lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
cb(lctx.inp_out_ids, "inp_out_ids", -1);
ggml_set_input(lctx.inp_out_ids);
return lctx.inp_out_ids;
}
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
@ -5809,6 +5822,8 @@ struct llm_build_context {
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
int32_t n_tokens = this->n_tokens;
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@ -5876,6 +5891,15 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
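// Note (not from this commit): ggml_get_rows() with inp_out_ids shrinks
// cur and inpSA from [n_embd, n_tokens] to [n_embd, n_outputs], so everything
// past the last attention block, including the [n_embd, n_vocab] output
// projection, is only computed for the requested rows. The same pattern is
// repeated near the end of every model graph below.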
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -6055,6 +6079,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -6170,6 +6202,15 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
}
struct ggml_tensor * ffn_inp = cur;
// feed forward
@ -6264,6 +6305,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -6461,6 +6510,14 @@ struct llm_build_context {
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
residual = ggml_get_rows(ctx0, residual, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
cb(ffn_inp, "ffn_inp", il);
@ -6550,6 +6607,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -6707,6 +6772,14 @@ struct llm_build_context {
}
cb(cur, "kqv_out", il);
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
@ -6829,6 +6902,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -6927,6 +7008,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -7040,6 +7129,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -7146,6 +7243,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -7258,6 +7363,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -7376,6 +7489,15 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
}
// FF
{
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
@ -7473,6 +7595,15 @@ struct llm_build_context {
cur = attention_norm;
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
@ -7565,6 +7696,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -7665,6 +7804,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -7774,6 +7921,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -7884,6 +8039,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -8007,6 +8170,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
@ -8121,6 +8292,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
@ -8234,6 +8413,14 @@ struct llm_build_context {
cb(cur, "kqv_out", il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@ -8381,6 +8568,16 @@ struct llm_build_context {
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
x = ggml_get_rows(ctx0, x, inp_out_ids);
y = ggml_get_rows(ctx0, y, inp_out_ids);
z = ggml_get_rows(ctx0, z, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
@ -8483,6 +8680,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
if (n_outputs == 0) { return gf; }
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
struct ggml_tensor * attn_out = cur;
// feed-forward network
@ -8773,9 +8978,38 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
{
GGML_ASSERT(lctx.inp_out_ids && "every model type must skip unused outputs");
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
int32_t * data = (int32_t *) lctx.inp_out_ids->data;
if (batch.logits) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (batch.logits[i]) {
data[n_outputs++] = i;
}
}
lctx.n_outputs = n_outputs;
} else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
lctx.n_outputs = n_tokens;
} else {
// only keep last output
data[0] = n_tokens - 1;
lctx.n_outputs = 1;
}
}
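// Illustration (not from this commit): for a ubatch of 5 tokens with
// batch.logits = {0, 0, 1, 0, 1}, the branches above produce
// inp_out_ids = {2, 4} and lctx.n_outputs = 2, so only those two rows
// survive into the final layer of the graph.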
GGML_ASSERT(
// (!a || b) is a logical implication (a -> b)
// !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
"non-causal attention with generative models is not supported"
"causal attention with embedding models is not supported"
);
if (lctx.inp_KQ_mask) {
@ -8954,6 +9188,62 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
// Only alloc when needed
static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
GGML_ASSERT(0 <= n_outputs);
const int32_t n_outputs_max = std::max((uint32_t) n_outputs, lctx.cparams.n_seq_max);
const auto n_batch = lctx.cparams.n_batch;
const auto n_vocab = lctx.model.hparams.n_vocab;
const auto n_embd = lctx.model.hparams.n_embd;
const int64_t capacity = lctx.output_size;
const bool has_logits = lctx.cparams.causal_attn;
const bool has_embd = lctx.cparams.embeddings;
if (!lctx.output_ids) {
// never resized afterwards
lctx.output_ids = (int32_t *) malloc(n_batch*sizeof(int32_t));
if (lctx.output_ids == nullptr) {
throw std::runtime_error("failed to allocate output_ids buffer");
}
}
// alloc only when more than the current logits capacity is required
if (capacity < n_outputs_max) {
if (lctx.buf_output) {
ggml_backend_buffer_free(lctx.buf_output);
lctx.buf_output = nullptr;
lctx.logits = nullptr;
lctx.embd = nullptr;
}
{
lctx.output_size = n_outputs_max;
lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0;
const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
if (lctx.buf_output == nullptr) {
throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
}
float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
lctx.logits = has_logits ? output_base : nullptr;
lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr;
}
}
// set all ids as invalid (assume two's complement negative numbers)
memset(lctx.output_ids, -1, n_batch*sizeof(int32_t));
ggml_backend_buffer_clear(lctx.buf_output, 0);
lctx.n_outputs = n_outputs; // also set in llama_set_inputs() before a batch
}
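// Rough numbers (illustrative, not from this commit): with n_vocab = 32000 and
// n_batch = 512, the old context-creation path reserved n_vocab*n_batch floats
// up front, i.e. 32000*512*4 bytes = 62.5 MiB of logits; after this change a
// single-sequence decode that requests one output keeps max(n_outputs, n_seq_max) = 1
// row, i.e. 32000*4 bytes = 125 KiB, and the buffer only grows when a batch
// requests more outputs.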
static void llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
@ -9029,16 +9319,8 @@ static int llama_decode_internal(
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
auto * logits_out = lctx.logits;
#ifndef NDEBUG
auto & logits_valid = lctx.logits_valid;
logits_valid.clear();
logits_valid.resize(n_tokens_all);
memset(logits_out, 0, lctx.logits_size*sizeof(float));
#endif
int32_t n_logits = 0;
int32_t n_logits_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
@ -9047,6 +9329,33 @@ static int llama_decode_internal(
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
// reserve output buffer
if (batch_all.logits) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.logits[i]) {
n_logits++;
}
}
llama_output_reserve(lctx, n_logits);
int32_t i_logits = 0;
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.logits[i]) {
lctx.output_ids[i] = i_logits++;
}
}
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
n_logits = n_tokens_all;
llama_output_reserve(lctx, n_logits);
for (uint32_t i = 0; i < n_tokens_all; ++i) {
lctx.output_ids[i] = i;
}
} else {
// keep last logits only
n_logits = 1;
llama_output_reserve(lctx, n_logits);
lctx.output_ids[0] = 0;
}
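// Illustration (not from this commit): with batch_all.logits = {1, 0, 0, 1},
// this block sets n_logits = 2 and output_ids = {0, -1, -1, 1} (the -1 entries
// come from the memset in llama_output_reserve), so a later
// llama_get_logits_ith(ctx, 3) resolves to row 1 of the compact logits buffer.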
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
@ -9125,20 +9434,26 @@ static int llama_decode_internal(
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
if (!hparams.causal_attn) {
if (lctx.n_outputs == 0) {
// no output
res = nullptr;
embd = nullptr;
} else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
// TODO: graph view to ignore the logits when not needed
} else {
if (strcmp(res->name, "result_output") == 0) {
// the token embeddings could be the second to last tensor, or the third to last tensor
if (strcmp(embd->name, "result_norm") != 0) {
embd = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
// the token embeddings could be the second to last tensor, or any of the previous tensors
// NOTE: see build_result_output() for an idea of up to how many tensors to skip
for (int i = 3; strcmp(embd->name, "result_norm") != 0 && i <= 10; ++i) {
embd = gf->nodes[gf->n_nodes - i];
}
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
} else {
GGML_ASSERT(false && "missing result_output tensor");
}
@ -9189,41 +9504,29 @@ static int llama_decode_internal(
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
int32_t new_logits = 0;
if (u_batch.logits) {
int32_t i_first = -1;
for (uint32_t i = 0; i < n_tokens; i++) {
if (u_batch.logits[i] && i_first == -1) {
i_first = (int32_t) i;
if (u_batch.logits[i]) {
new_logits++;
}
if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
if (i_first != -1) {
int i_last = u_batch.logits[i] == 0 ? i : i + 1;
// extract logits for the range [i_first, i_last)
// group the requests to minimize the number of calls to the backend
ggml_backend_tensor_get_async(backend_res, res,
logits_out + n_vocab*(cur_token + i_first),
i_first*n_vocab*sizeof(float),
(i_last - i_first)*n_vocab*sizeof(float));
i_first = -1;
}
}
#ifndef NDEBUG
logits_valid[cur_token + i] = u_batch.logits[i] != 0;;
#endif
}
} else if (lctx.logits_all) {
ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
#ifndef NDEBUG
std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
#endif
new_logits += n_tokens;
} else {
// keep last logits only
if (cur_token + n_tokens >= n_tokens_all) {
ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
#ifndef NDEBUG
logits_valid[0] = true;
#endif
new_logits += 1;
}
}
if (new_logits) {
GGML_ASSERT(new_logits <= n_logits);
GGML_ASSERT((n_logits_prev+new_logits)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, lctx.logits, n_logits_prev*n_vocab*sizeof(float), new_logits*n_vocab*sizeof(float));
n_logits_prev += new_logits;
}
}
// extract embeddings
@ -9243,6 +9546,7 @@ static int llama_decode_internal(
if (u_batch.logits[i] == 0) {
continue;
}
// FIXME
ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
}
}
@ -13011,7 +13315,7 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
// TODO: maybe add n_seq_max here too
cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -13214,25 +13518,14 @@ struct llama_context * llama_new_context_with_model(
// graph outputs buffer
{
// resized during inference, reserve maximum
ctx->logits_size = hparams.n_vocab*cparams.n_batch;
ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
if (ctx->buf_output == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
// resized during inference when more than n_seq_max logits are requested in a batch
try {
llama_output_reserve(*ctx, 0);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error reserving logits buffer: %s\n", __func__, err.what());
llama_free(ctx);
return nullptr;
}
ggml_backend_buffer_clear(ctx->buf_output, 0);
ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
if (params.embeddings) {
ctx->embd = ctx->logits + ctx->logits_size;
}
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(ctx->buf_output),
@ -14269,17 +14562,20 @@ void llama_synchronize(struct llama_context * ctx) {
}
float * llama_get_logits(struct llama_context * ctx) {
// TODO: assert that really all logits are in the output
llama_synchronize(ctx);
return ctx->logits;
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
assert(ctx->logits_valid.at(i));
const int32_t j = ctx->output_ids[i];
GGML_ASSERT(0 <= j);
llama_synchronize(ctx);
return ctx->logits + i*ctx->model.hparams.n_vocab;
// FIXME: check for nullptr
return ctx->logits + j*ctx->model.hparams.n_vocab;
}
float * llama_get_embeddings(struct llama_context * ctx) {
@ -14289,9 +14585,13 @@ float * llama_get_embeddings(struct llama_context * ctx) {
}
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
GGML_ASSERT(0 <= j);
llama_synchronize(ctx);
return ctx->embd + i*ctx->model.hparams.n_embd;
// FIXME: check for nullptr
return ctx->embd + j*ctx->model.hparams.n_embd;
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {


@ -674,6 +674,7 @@ extern "C" {
LLAMA_API void llama_synchronize(struct llama_context * ctx);
// Token logits obtained from the last call to llama_decode()
// WARNING: the following layout is only valid when the batch outputs logits for all tokens
// The logits for the last token are stored in the last row
// Logits for which llama_batch.logits[i] == 0 are undefined
// Rows: n_tokens provided with llama_batch
@ -681,10 +682,11 @@ extern "C" {
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. Equivalent to:
// llama_get_logits(ctx) + i*n_vocab
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get all output token embeddings
// WARNING: only use when all outputs are requested
// shape: [n_tokens*n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
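
For reference, here is a short usage sketch of the API above (not part of this commit; it assumes the llama_batch_init/llama_batch_free helpers and llama_batch fields from the same version of llama.h, omits error handling and model/context setup, and the function name is made up). It requests logits for only the last prompt token, which is exactly the case where this change shrinks the output buffer to a single row:

#include "llama.h"

// Decode n_tokens prompt tokens in sequence 0 and return the logits of the
// last one; with this change only that single row is stored in the context.
static float * decode_prompt_last_logits(llama_context * ctx, const llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 1);
    batch.n_tokens = n_tokens;
    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == n_tokens - 1); // request an output only for the last token
    }
    float * logits = nullptr;
    if (llama_decode(ctx, batch) == 0) {
        // the index is a token position; output_ids maps it to row 0 internally
        logits = llama_get_logits_ith(ctx, n_tokens - 1);
    }
    llama_batch_free(batch);
    return logits;
}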