llama : fix lctx.n_outputs not being set before building graph

parent 98914c0ed0
commit 705d3937ea

1 changed file with 79 additions and 68 deletions

llama.cpp (145 changed lines, +79 −68)
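
In short: llama_output_reserve() now only reserves whole-batch capacity (and resets lctx.n_outputs to 0), while llama_decode_internal() counts the outputs of each ubatch and stores that count in lctx.n_outputs before the graph is built; llama_set_inputs() asserts that the two counts agree, and each ubatch's rows are appended to the output buffers at offset n_outputs_prev. A rough outline of the per-ubatch flow after this change, paraphrased from the hunks below (not literal code; graph building and compute are summarized as comments):

    llama_output_reserve(lctx, n_outputs);   // whole-batch capacity; leaves lctx.n_outputs == 0
    int32_t n_outputs_prev = 0;
    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_tokens) {
        lctx.n_outputs = /* outputs requested in this ubatch */;  // must be set before the graph is built
        // ... build the graph for u_batch and call llama_set_inputs(lctx, u_batch),
        //     which now asserts GGML_ASSERT(lctx.n_outputs == n_outputs) ...
        // ... compute ...
        // append this ubatch's logit rows right after the previous ones
        ggml_backend_tensor_get_async(backend_res, res,
                lctx.logits + n_outputs_prev*n_vocab, 0, lctx.n_outputs*n_vocab*sizeof(float));
        n_outputs_prev += lctx.n_outputs;
    }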
@@ -2103,7 +2103,7 @@ struct llama_context {
     int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current or previous batch
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current or previous ubatch
 
     bool logits_all = false;
 
@@ -8985,24 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
 
-        if (batch.logits) {
-            int32_t n_outputs = 0;
+        int32_t n_outputs = 0;
+        if (batch.logits) {
             for (int i = 0; i < n_tokens; ++i) {
                 if (batch.logits[i]) {
                     data[n_outputs++] = i;
                 }
             }
-            lctx.n_outputs = n_outputs;
         } else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-            lctx.n_outputs = n_tokens;
+            n_outputs = n_tokens;
         } else {
             // only keep last output
             data[0] = n_tokens - 1;
-            lctx.n_outputs = 1;
+            n_outputs = 1;
         }
+        // the graph needs the have been passed the correct number of outputs
+        GGML_ASSERT(lctx.n_outputs == n_outputs);
     }
 
     GGML_ASSERT(
@@ -9202,6 +9203,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const auto n_embd = hparams.n_embd;
     const int64_t capacity = lctx.output_size;
 
+    // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
     const bool has_embd   = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
@@ -9221,10 +9223,11 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
 
     if (lctx.buf_output) {
-#ifndef NDEBUG
+        // This doesn't happen often
+        // #ifndef NDEBUG
         const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
-        fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0);
-#endif
+        LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
+        // #endif
         ggml_backend_buffer_free(lctx.buf_output);
         lctx.buf_output = nullptr;
         lctx.logits = nullptr;
@@ -9246,7 +9249,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
 
     ggml_backend_buffer_clear(lctx.buf_output, 0);
 
-    lctx.n_outputs = n_outputs; // also set in llama_set_inputs() before a batch
+    lctx.n_outputs = 0;
 }
 
@@ -9325,8 +9328,8 @@ static int llama_decode_internal(
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
-    int32_t n_logits      = 0;
-    int32_t n_logits_prev = 0;
+    int32_t n_outputs      = 0;
+    int32_t n_outputs_prev = 0;
 
     const auto n_ubatch = cparams.n_ubatch;
 
@@ -9338,11 +9341,9 @@ static int llama_decode_internal(
     // reserve output buffer
     if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
-                n_logits++;
-            }
+            n_outputs += batch_all.logits[i] != 0;
         }
-        llama_output_reserve(lctx, n_logits);
+        llama_output_reserve(lctx, n_outputs);
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch_all.logits[i]) {
@@ -9350,15 +9351,15 @@ static int llama_decode_internal(
             }
         }
     } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
-        n_logits = n_tokens_all;
-        llama_output_reserve(lctx, n_logits);
+        n_outputs = n_tokens_all;
+        llama_output_reserve(lctx, n_outputs);
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             lctx.output_ids[i] = i;
         }
     } else {
-        // keep last logits only
-        n_logits = 1;
-        llama_output_reserve(lctx, n_logits);
+        // keep last output only
+        n_outputs = 1;
+        llama_output_reserve(lctx, n_outputs);
         lctx.output_ids[0] = 0;
     }
 
@@ -9377,6 +9378,27 @@ static int llama_decode_internal(
             /* .all_seq_id = */ batch_all.all_seq_id,
         };
 
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (u_batch.logits) {
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    n_outputs_new += u_batch.logits[i] != 0;
+                }
+            } else if (lctx.logits_all) {
+                n_outputs_new = n_tokens;
+            } else {
+                // keep last output only
+                if (cur_token + n_tokens >= n_tokens_all) {
+                    n_outputs_new = 1;
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         GGML_ASSERT(n_threads > 0);
 
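
The counting block added above can be read as a small helper: given the ubatch covering positions [cur_token, cur_token + n_tokens) of a batch with n_tokens_all tokens, it computes how many output rows this ubatch contributes. A hypothetical standalone paraphrase (not part of the commit; `logits` is assumed to point at the whole-batch flags):

    // hypothetical paraphrase of the per-ubatch output count added in the hunk above
    static int32_t count_ubatch_outputs(const int8_t * logits, bool logits_all,
            uint32_t cur_token, uint32_t n_tokens, uint32_t n_tokens_all) {
        int32_t n_outputs_new = 0;
        if (logits) {
            for (uint32_t i = 0; i < n_tokens; i++) {
                n_outputs_new += logits[cur_token + i] != 0; // only explicitly requested positions
            }
        } else if (logits_all) {
            n_outputs_new = n_tokens;                        // every position produces an output
        } else if (cur_token + n_tokens >= n_tokens_all) {
            n_outputs_new = 1;                               // only the ubatch holding the last token
        }
        return n_outputs_new;
    }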
@@ -9451,18 +9473,26 @@ static int llama_decode_internal(
             embd = gf->nodes[gf->n_nodes - 1];
 
             GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
-            // TODO: graph view to ignore the logits when not needed
-        } else {
-            if (strcmp(res->name, "result_output") == 0) {
-                // the token embeddings could be the second to last tensor, or any of the previous tensors
-                // NOTE: see build_result_output() for an idea of up to how many tensors to skip
-                for (int i = 3; strcmp(embd->name, "result_norm") != 0 && i <= 10; ++i) {
-                    embd = gf->nodes[gf->n_nodes - i];
-                }
-                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-            } else {
-                GGML_ASSERT(false && "missing result_output tensor");
-            }
+        } else if (cparams.embeddings) {
+            // the embeddings could be in the second to last tensor, or any of the previous tensors
+            int i_embd = gf->n_nodes - 2;
+            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+                i_embd = gf->n_nodes - i;
+                if (i_embd < 0) { break; }
+                embd = gf->nodes[i_embd];
+            }
+            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+            if (!cparams.causal_attn) {
+                res = nullptr; // do not extract logits when not needed
+                // skip computing logits
+                // TODO: is this safe?
+                gf->n_nodes = i_embd + 1;
+            }
+        } else {
+            embd = nullptr; // do not extract embeddings when not needed
+            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -9505,38 +9535,22 @@ static int llama_decode_internal(
         //}
 
         // extract logits
-        // TODO: do not compute and extract logits if only embeddings are needed
-        //       update the graphs to skip "result_output" if logits are not needed
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
             GGML_ASSERT(backend_res != nullptr);
-            int32_t new_logits = 0;
 
-            if (u_batch.logits) {
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    if (u_batch.logits[i]) {
-                        new_logits++;
-                    }
-                }
-            } else if (lctx.logits_all) {
-                new_logits += n_tokens;
-            } else {
-                // keep last logits only
-                if (cur_token + n_tokens >= n_tokens_all) {
-                    new_logits += 1;
-                }
-            }
+            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+            const int32_t n_outputs_new = lctx.n_outputs;
 
-            if (new_logits) {
-                GGML_ASSERT(new_logits <= n_logits);
-                GGML_ASSERT((n_logits_prev+new_logits)*n_vocab <= (int64_t) lctx.logits_size);
-                ggml_backend_tensor_get_async(backend_res, res, lctx.logits, n_logits_prev*n_vocab*sizeof(float), new_logits*n_vocab*sizeof(float));
-                n_logits_prev += new_logits;
+            if (n_outputs_new) {
+                GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
             }
         }
 
         // extract embeddings
-        if (cparams.embeddings && embd) {
+        if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -9544,17 +9558,13 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_NONE:
                     {
                         // extract token embeddings
-                        auto & embd_out = lctx.embd;
+                        float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+                        const int32_t n_outputs_new = lctx.n_outputs;
 
-                        if (u_batch.logits) {
-                            //embd_out.resize(n_embd * n_tokens);
-                            for (uint32_t i = 0; i < n_tokens; i++) {
-                                if (u_batch.logits[i] == 0) {
-                                    continue;
-                                }
-                                // FIXME
-                                ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                            }
+                        if (n_outputs_new) {
+                            GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                            GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_CLS:
@@ -9581,6 +9591,7 @@ static int llama_decode_internal(
                     } break;
             }
         }
+        n_outputs_prev += lctx.n_outputs;
     }
 
     // wait for the computation to finish (automatically done when obtaining the model output)
@@ -14639,11 +14650,11 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
 
     llama_synchronize(ctx);
 
-    if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
+    if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
         return ctx->logits + j*ctx->model.hparams.n_vocab;
     }
-    LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
+    LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
+            __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", j, ctx->output_size);
 #ifndef NDEBUG
     GGML_ASSERT(false);
 #endif
@@ -14661,7 +14672,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
 
     llama_synchronize(ctx);
 
-    if (ctx->embd && 0 < j && j < ctx->n_outputs) {
+    if (ctx->embd && 0 < j && (size_t) j < ctx->output_size) {
         return ctx->embd + j*ctx->model.hparams.n_embd;
     }
     LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);
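
From the caller's side the behavior is the same as before, but misuse is now diagnosed more precisely: only positions flagged in batch.logits get a row in the output buffers, and requesting any other position logs which of the three failure reasons applies. A hypothetical usage sketch against the public llama.h API of this period (ctx, prompt_tokens and n_prompt are placeholders; error handling omitted):

    llama_batch batch = llama_batch_init(n_prompt, 0, 1);
    for (int i = 0; i < n_prompt; ++i) {
        batch.token[i]     = prompt_tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == n_prompt - 1); // request logits for the last position only
    }
    batch.n_tokens = n_prompt;

    if (llama_decode(ctx, batch) == 0) {
        // valid: this position was flagged in batch.logits
        float * last_logits = llama_get_logits_ith(ctx, n_prompt - 1);
        // not valid: position 0 was not flagged, so its output id stays negative and
        // llama_get_logits_ith(ctx, 0) now logs the "batch.logits[i] wasn't true" reason
        (void) last_logits;
    }
    llama_batch_free(batch);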