llama : fix lctx.n_outputs not being set before building graph
parent 98914c0ed0
commit 705d3937ea

1 changed file with 79 additions and 68 deletions

llama.cpp (+79 −68)
@@ -2103,7 +2103,7 @@ struct llama_context {
     int32_t * output_ids  = nullptr; // map token positions to ids of the logits and embd buffers
     size_t    output_size = 0;       // capacity (of tokens positions) for the output buffers
-    int32_t   n_outputs   = 0;       // number of actually-used outputs in the current or previous batch
+    int32_t   n_outputs   = 0;       // number of actually-used outputs in the current or previous ubatch

     bool logits_all = false;

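For context (not stated in the diff itself): llama_decode_internal processes a batch in micro-batches ("ubatches") of at most cparams.n_ubatch tokens, and this rename reflects that lctx.n_outputs now tracks the outputs of the ubatch currently being decoded rather than of the whole batch. A minimal, self-contained sketch of that split, with an illustrative loop body (the printf stands in for the real per-ubatch work):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Sketch only: how a batch of n_tokens_all tokens is walked in ubatch-sized
    // chunks; the real per-ubatch work (count outputs, build graph, compute,
    // extract results) is summarized by the printf below.
    static void walk_ubatches(uint32_t n_tokens_all, uint32_t n_ubatch) {
        for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
            const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
            printf("ubatch at token %u with %u tokens\n", cur_token, n_tokens);
        }
    }

    int main() {
        walk_ubatches(/*n_tokens_all =*/ 10, /*n_ubatch =*/ 4);
    }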
@@ -8985,24 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
             int32_t * data = (int32_t *) lctx.inp_out_ids->data;

+            int32_t n_outputs = 0;
             if (batch.logits) {
-                int32_t n_outputs = 0;
                 for (int i = 0; i < n_tokens; ++i) {
                     if (batch.logits[i]) {
                         data[n_outputs++] = i;
                     }
                 }
-                lctx.n_outputs = n_outputs;
             } else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
                 for (int i = 0; i < n_tokens; ++i) {
                     data[i] = i;
                 }
-                lctx.n_outputs = n_tokens;
+                n_outputs = n_tokens;
             } else {
                 // only keep last output
                 data[0] = n_tokens - 1;
-                lctx.n_outputs = 1;
+                n_outputs = 1;
             }
+            // the graph needs to have been passed the correct number of outputs
+            GGML_ASSERT(lctx.n_outputs == n_outputs);
         }

         GGML_ASSERT(

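With this hunk, llama_set_inputs no longer decides how many outputs the ubatch has; it recomputes the count locally and asserts that it matches the lctx.n_outputs value chosen before the graph was built. A self-contained sketch of that counting rule (the helper name count_outputs and its raw-pointer parameters are my own, not the commit's API):

    #include <cstdint>

    // Counting rule mirrored from the hunk above (assumed helper, not in llama.cpp):
    // - with per-token flags, count the flagged tokens
    // - with logits_all (or pooled embeddings), every token is an output
    // - otherwise only the last token of the ubatch produces an output
    static int32_t count_outputs(const int8_t * logits, int32_t n_tokens, bool logits_all) {
        if (logits) {
            int32_t n_outputs = 0;
            for (int32_t i = 0; i < n_tokens; ++i) {
                n_outputs += logits[i] != 0;
            }
            return n_outputs;
        }
        return logits_all ? n_tokens : 1;
    }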
@@ -9202,6 +9203,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
+    const int64_t capacity = lctx.output_size;

     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
     const bool has_embd   = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

@@ -9221,10 +9223,11 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);

     if (lctx.buf_output) {
-#ifndef NDEBUG
+        // This doesn't happen often
+        // #ifndef NDEBUG
         const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
-        fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
-#endif
+        LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
+        // #endif
         ggml_backend_buffer_free(lctx.buf_output);
         lctx.buf_output = nullptr;
         lctx.logits = nullptr;

@@ -9246,7 +9249,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {

     ggml_backend_buffer_clear(lctx.buf_output, 0);

-    lctx.n_outputs = n_outputs; // also set in llama_set_inputs() before a batch
+    lctx.n_outputs = 0;
 }

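After these two hunks, llama_output_reserve only guarantees capacity: it (re)allocates the logits/embd buffers for up to n_outputs rows and resets lctx.n_outputs to 0, leaving the per-ubatch value to be set later in llama_decode_internal. A rough sketch of that reserve pattern under simplified assumptions (OutputBuffers and std::vector storage are stand-ins; the real code allocates a ggml backend buffer):

    #include <cstdint>
    #include <vector>

    // Simplified stand-in for the context's output storage (assumption, not llama.cpp's type).
    struct OutputBuffers {
        std::vector<float> logits;  // capacity rows * n_vocab
        std::vector<float> embd;    // capacity rows * n_embd
        int32_t capacity  = 0;      // like lctx.output_size
        int32_t n_outputs = 0;      // like lctx.n_outputs (per-ubatch, set elsewhere)
    };

    // Reserve room for n_outputs rows; never decides how many rows are actually used.
    static void reserve_outputs(OutputBuffers & out, int32_t n_outputs, int32_t n_vocab, int32_t n_embd) {
        if (n_outputs > out.capacity) {
            out.logits.resize((size_t) n_outputs * n_vocab);
            out.embd.resize((size_t) n_outputs * n_embd);
            out.capacity = n_outputs;
        }
        out.n_outputs = 0; // the real count is set per ubatch, before the graph is built
    }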
@@ -9325,8 +9328,8 @@ static int llama_decode_internal(
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;

-    int32_t n_logits = 0;
-    int32_t n_logits_prev = 0;
+    int32_t n_outputs = 0;
+    int32_t n_outputs_prev = 0;

     const auto n_ubatch = cparams.n_ubatch;

@@ -9338,11 +9341,9 @@ static int llama_decode_internal(
     // reserve output buffer
     if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
-                n_logits++;
-            }
+            n_outputs += batch_all.logits[i] != 0;
         }
-        llama_output_reserve(lctx, n_logits);
+        llama_output_reserve(lctx, n_outputs);
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch_all.logits[i]) {

@@ -9350,15 +9351,15 @@ static int llama_decode_internal(
             }
         }
     } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
-        n_logits = n_tokens_all;
-        llama_output_reserve(lctx, n_logits);
+        n_outputs = n_tokens_all;
+        llama_output_reserve(lctx, n_outputs);
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             lctx.output_ids[i] = i;
         }
     } else {
-        // keep last logits only
-        n_logits = 1;
-        llama_output_reserve(lctx, n_logits);
+        // keep last output only
+        n_outputs = 1;
+        llama_output_reserve(lctx, n_outputs);
         lctx.output_ids[0] = 0;
     }

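Before any ubatch runs, the whole batch is scanned once so that the output buffers are reserved for the total number of requested outputs, and lctx.output_ids maps each token's batch position to its row in those buffers. A small illustration of that mapping for a hypothetical 6-token batch where only tokens 2 and 5 request logits (example data, not from the commit):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // batch.logits-style flags for 6 tokens: only tokens 2 and 5 request an output
        const int8_t logits[6] = { 0, 0, 1, 0, 0, 1 };

        std::vector<int32_t> output_ids(6, -1); // like lctx.output_ids
        int32_t n_outputs = 0;

        for (int32_t i = 0; i < 6; ++i) {
            if (logits[i]) {
                output_ids[i] = n_outputs++; // token position -> row in the output buffers
            }
        }

        // prints: token 2 -> row 0, token 5 -> row 1, every other token has no row (-1)
        for (int32_t i = 0; i < 6; ++i) {
            printf("token %d -> row %d\n", i, output_ids[i]);
        }
    }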
@@ -9377,6 +9378,27 @@ static int llama_decode_internal(
             /* .all_seq_id = */ batch_all.all_seq_id,
         };

+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (u_batch.logits) {
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    n_outputs_new += u_batch.logits[i] != 0;
+                }
+            } else if (lctx.logits_all) {
+                n_outputs_new = n_tokens;
+            } else {
+                // keep last output only
+                if (cur_token + n_tokens >= n_tokens_all) {
+                    n_outputs_new = 1;
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         GGML_ASSERT(n_threads > 0);

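This new block is the fix named in the commit title: lctx.n_outputs is now assigned per ubatch before llama_build_graph and llama_set_inputs run ("needs to happen before the graph is built"), so the graph and the assert added earlier both see the correct count. The branch worth a second look is the default one: without per-token flags and without logits_all, only the ubatch containing the final token of the whole batch produces an output. A runnable illustration of that condition over a hypothetical 10-token batch split into ubatches of 4:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_tokens_all = 10; // illustrative sizes, not from the commit
        const uint32_t n_ubatch     = 4;

        for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
            const uint32_t n_tokens = n_tokens_all - cur_token < n_ubatch ? n_tokens_all - cur_token : n_ubatch;

            int32_t n_outputs_new = 0;
            if (cur_token + n_tokens >= n_tokens_all) {
                n_outputs_new = 1; // same "keep last output only" condition as the added code
            }
            printf("ubatch at token %u: n_outputs = %d\n", cur_token, n_outputs_new);
        }
        // prints 0, 0, 1 for the ubatches starting at tokens 0, 4 and 8
    }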
@@ -9451,18 +9473,26 @@ static int llama_decode_internal(
             embd = gf->nodes[gf->n_nodes - 1];

             GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
+            // TODO: graph view to ignore the logits when not needed
-        } else {
-            if (strcmp(res->name, "result_output") == 0) {
-                // the token embeddings could be the second to last tensor, or any of the previous tensors
-                // NOTE: see build_result_output() for an idea of up to how many tensors to skip
-                for (int i = 3; strcmp(embd->name, "result_norm") != 0 && i <= 10; ++i) {
-                    embd = gf->nodes[gf->n_nodes - i];
-                }
-                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-            } else {
-                GGML_ASSERT(false && "missing result_output tensor");
-            }
+        } else if (cparams.embeddings) {
+            // the embeddings could be in the second to last tensor, or any of the previous tensors
+            int i_embd = gf->n_nodes - 2;
+            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+                i_embd = gf->n_nodes - i;
+                if (i_embd < 0) { break; }
+                embd = gf->nodes[i_embd];
+            }
+            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+            if (!cparams.causal_attn) {
+                res = nullptr; // do not extract logits when not needed
+                // skip computing logits
+                // TODO: is this safe?
+                gf->n_nodes = i_embd + 1;
+            }
+        } else {
+            embd = nullptr; // do not extract embeddings when not needed
+            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

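The new else-if branch locates the last normalization output by walking backwards through the graph nodes until it finds the tensor named "result_norm", and remembers its index so that, for non-causal embedding runs, the logits nodes after it can be dropped by truncating gf->n_nodes to i_embd + 1. A standalone sketch of that backwards name scan over a plain array (GraphNode and the sample node names are stand-ins; the real code walks ggml_cgraph nodes and ggml_tensor names):

    #include <cstdio>
    #include <cstring>

    // Stand-in for a graph node; the real code inspects ggml_tensor::name.
    struct GraphNode { const char * name; };

    int main() {
        // hypothetical tail of a compute graph
        GraphNode nodes[] = { {"ffn_out"}, {"result_norm"}, {"result_output"} };
        const int n_nodes = 3;

        // walk backwards from the second-to-last node until "result_norm" is found
        int i_embd = n_nodes - 2;
        while (i_embd >= 0 && strcmp(nodes[i_embd].name, "result_norm") != 0) {
            i_embd--;
        }

        if (i_embd >= 0) {
            // truncating the node list here is what skips computing the logits
            printf("result_norm is node %d; compute only nodes [0, %d]\n", i_embd, i_embd);
        } else {
            printf("missing result_norm tensor\n");
        }
    }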
@@ -9505,38 +9535,22 @@ static int llama_decode_internal(
         //}

         // extract logits
-        // TODO: do not compute and extract logits if only embeddings are needed
-        // update the graphs to skip "result_output" if logits are not needed
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
             GGML_ASSERT(backend_res != nullptr);
-            int32_t new_logits = 0;
-
-            if (u_batch.logits) {
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    if (u_batch.logits[i]) {
-                        new_logits++;
-                    }
-                }
-            } else if (lctx.logits_all) {
-                new_logits += n_tokens;
-            } else {
-                // keep last logits only
-                if (cur_token + n_tokens >= n_tokens_all) {
-                    new_logits += 1;
-                }
-            }
+            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+            const int32_t n_outputs_new = lctx.n_outputs;

-            if (new_logits) {
-                GGML_ASSERT(new_logits <= n_logits);
-                GGML_ASSERT((n_logits_prev+new_logits)*n_vocab <= (int64_t) lctx.logits_size);
-                ggml_backend_tensor_get_async(backend_res, res, lctx.logits, n_logits_prev*n_vocab*sizeof(float), new_logits*n_vocab*sizeof(float));
-                n_logits_prev += new_logits;
+            if (n_outputs_new) {
+                GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
             }
         }

         // extract embeddings
-        if (cparams.embeddings && embd) {
+        if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
             GGML_ASSERT(backend_embd != nullptr);

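The extraction code no longer re-derives how many logits this ubatch produced; it reads lctx.n_outputs (set before the graph was built) and copies that many contiguous rows of n_vocab floats into lctx.logits at row offset n_outputs_prev, which advances once per ubatch (the @ -9581 hunk below adds n_outputs_prev += lctx.n_outputs). The embeddings path in the next hunk mirrors the same scheme with n_embd in place of n_vocab. A small arithmetic illustration of those offsets with made-up numbers:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_vocab = 32000; // illustrative vocab size

        // suppose three ubatches produce 2, 0 and 1 output rows respectively
        const int32_t outputs_per_ubatch[3] = { 2, 0, 1 };

        int32_t n_outputs_prev = 0;
        for (int u = 0; u < 3; ++u) {
            const int32_t n_outputs_new = outputs_per_ubatch[u];
            if (n_outputs_new) {
                // destination offset (in floats) and size of the single contiguous copy
                const int64_t dst_offset = (int64_t) n_outputs_prev * n_vocab;
                const int64_t n_floats   = (int64_t) n_outputs_new  * n_vocab;
                printf("ubatch %d: copy %lld floats to lctx.logits + %lld\n",
                       u, (long long) n_floats, (long long) dst_offset);
            }
            n_outputs_prev += n_outputs_new;
        }
    }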
@@ -9544,17 +9558,13 @@ static int llama_decode_internal(
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
-                    auto & embd_out = lctx.embd;
+                    float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+                    const int32_t n_outputs_new = lctx.n_outputs;

-                    if (u_batch.logits) {
-                        //embd_out.resize(n_embd * n_tokens);
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            if (u_batch.logits[i] == 0) {
-                                continue;
-                            }
-                            // FIXME
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                        }
+                    if (n_outputs_new) {
+                        GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                        GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_CLS:

@@ -9581,6 +9591,7 @@ static int llama_decode_internal(
                 } break;
             }
         }
+        n_outputs_prev += lctx.n_outputs;
     }

     // wait for the computation to finish (automatically done when obtaining the model output)

@@ -14639,11 +14650,11 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];

     llama_synchronize(ctx);

-    if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
+    if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
         return ctx->logits + j*ctx->model.hparams.n_vocab;
     }
-    LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
+    LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
+        __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", j, ctx->output_size);
 #ifndef NDEBUG
     GGML_ASSERT(false);
 #endif

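Because lctx.n_outputs now changes from ubatch to ubatch, the accessors bound-check against the stable ctx->output_size (the reserved capacity) instead, while ctx->output_ids[i] still maps the token's batch position to its row. A hedged usage sketch of the logits accessor (assumes a llama_context that has already decoded a batch in which token i requested logits; the nullptr check assumes the error path above ends by returning nullptr in non-debug builds):

    #include "llama.h"
    #include <cstdio>

    // Usage sketch only: read the logits row of the token at batch position i.
    static bool print_first_logit(struct llama_context * ctx, int32_t i) {
        float * logits = llama_get_logits_ith(ctx, i); // n_vocab floats for that token
        if (logits == nullptr) {
            return false; // the token did not request an output (or i was out of range)
        }
        printf("first logit of token %d: %f\n", i, logits[0]);
        return true;
    }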
@@ -14661,7 +14672,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {

     llama_synchronize(ctx);

-    if (ctx->embd && 0 < j && j < ctx->n_outputs) {
+    if (ctx->embd && 0 < j && (size_t) j < ctx->output_size) {
         return ctx->embd + j*ctx->model.hparams.n_embd;
     }
     LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);