llama : fix wrong n_outputs in llama_set_inputs
A mismatch happened when using an n_ubatch smaller than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now depends almost entirely on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch.
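To make the mismatch concrete, here is a minimal standalone sketch (an editor's illustration, not code from this commit). It assumes that llama_batch_get_one() leaves batch.logits unset, so only the last token of the whole batch produces an output, and that llama_decode_internal credits that single output to the ubatch containing the last token; the batch sizes are made up.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_batch  = 512; // tokens passed to one llama_decode call
    const int32_t n_ubatch = 256; // tokens processed per ubatch

    for (int32_t start = 0; start < n_batch; start += n_ubatch) {
        const int32_t n_tokens = std::min(n_ubatch, n_batch - start);

        // what llama_decode_internal stores in lctx.n_outputs for this ubatch:
        // the single output belongs only to the ubatch that contains the last token
        const int32_t n_outputs_ctx = (start + n_tokens == n_batch) ? 1 : 0;

        // what the old llama_set_inputs recomputed locally for the same ubatch:
        // "no logits array -> keep the last output", i.e. always 1
        const int32_t n_outputs_old = 1;

        printf("ubatch at %3d: lctx.n_outputs = %d, old local n_outputs = %d%s\n",
               start, n_outputs_ctx, n_outputs_old,
               n_outputs_ctx != n_outputs_old ? "  <-- assertion would fail" : "");
    }
    return 0;
}

With these numbers the first ubatch disagrees (0 vs 1), which is the assertion failure the llama_set_inputs change below removes by trusting lctx.n_outputs instead of recomputing it.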
This commit is contained in:
parent 408fcb0f91
commit e19cb3aeb7

1 changed file with 20 additions and 20 deletions

llama.cpp | 40 (+20, -20)
@@ -2103,7 +2103,7 @@ struct llama_context {

     int32_t * output_ids  = nullptr; // map token positions to ids of the logits and embd buffers
     size_t    output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t   n_outputs   = 0; // number of actually-used outputs in the current or previous ubatch
+    int32_t   n_outputs   = 0; // number of actually-used outputs in the current ubatch

     bool logits_all = false;

@@ -8985,25 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;

-        int32_t n_outputs = 0;
-        if (batch.logits) {
+        if (lctx.n_outputs == n_tokens) {
+            for (int i = 0; i < n_tokens; ++i) {
+                data[i] = i;
+            }
+        } else if (batch.logits) {
+            int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
                 if (batch.logits[i]) {
                     data[n_outputs++] = i;
                 }
             }
-        } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
-            for (int i = 0; i < n_tokens; ++i) {
-                data[i] = i;
-            }
-            n_outputs = n_tokens;
-        } else {
+            // the graph needs to have been passed the correct number of outputs
+            GGML_ASSERT(lctx.n_outputs == n_outputs);
+        } else if (lctx.n_outputs == 1) {
             // only keep last output
             data[0] = n_tokens - 1;
-            n_outputs = 1;
+        } else {
+            GGML_ASSERT(lctx.n_outputs == 0);
         }
-        // the graph needs to have been passed the correct number of outputs
-        GGML_ASSERT(lctx.n_outputs == n_outputs);
     }

     GGML_ASSERT(
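For readability, the new selection logic above can be restated as a small pure function. This is an illustrative sketch only (the helper name and the vector return type are not from llama.cpp), with lctx.n_outputs passed in as n_outputs.

#include <cstdint>
#include <vector>

// returns the token positions of this ubatch whose outputs are kept,
// given the n_outputs already decided by the caller
static std::vector<int32_t> select_out_ids(int32_t n_outputs, int32_t n_tokens, const int8_t * logits) {
    std::vector<int32_t> out;
    if (n_outputs == n_tokens) {
        for (int32_t i = 0; i < n_tokens; ++i) { out.push_back(i); } // keep every position
    } else if (logits) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (logits[i]) { out.push_back(i); }                     // keep requested positions
        }
    } else if (n_outputs == 1) {
        out.push_back(n_tokens - 1);                                 // keep only the last token
    }
    // n_outputs == 0: keep nothing; in every branch out.size() must equal n_outputs
    return out;
}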
@@ -9386,7 +9386,7 @@ static int llama_decode_internal(
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
-            } else if (lctx.logits_all) {
+            } else if ((uint32_t) n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
                 // keep last output only
@@ -14166,7 +14166,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat

     // copy outputs
     {
-        size_t n_outputs = ctx->n_outputs;
+        // Can't use ctx->n_outputs because it's not for the
+        // entire last batch when n_ubatch is smaller than n_batch
+        size_t n_outputs = 0;

         // copy output ids
         {
@@ -14174,19 +14176,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             const size_t    n_batch    = ctx->cparams.n_batch;
             const int32_t * output_ids = ctx->output_ids;

-            output_pos.resize(n_outputs);
+            output_pos.resize(ctx->output_size);

             // build a more compact representation of the output ids
             for (size_t i = 0; i < n_batch; ++i) {
                 // map an output id to a position in the batch
                 int32_t pos = output_ids[i];
                 if (pos >= 0) {
-                    if ((size_t) pos >= output_pos.size()) {
-                        // TODO: maybe fail here instead
-                        LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
+                    if ((size_t) pos >= n_outputs) {
                         n_outputs = pos + 1;
-                        output_pos.resize(n_outputs);
                     }
+                    GGML_ASSERT((size_t) pos < ctx->output_size);
                     output_pos[pos] = i;
                 }
             }

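The recalculation above can be exercised in isolation. A minimal sketch with made-up values (not llama.cpp code): output_ids maps each batch position to an output slot or -1, so one scan both builds the compact position list and recovers the whole-batch n_outputs; here a 6-token batch produced outputs at positions 2 and 5.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical layout after decoding: positions 2 and 5 map to output slots 0 and 1
    const std::vector<int32_t> output_ids = { -1, -1, 0, -1, -1, 1 };
    const size_t output_size = 2; // capacity of the output buffers

    std::vector<int32_t> output_pos(output_size, -1);
    size_t n_outputs = 0;

    for (size_t i = 0; i < output_ids.size(); ++i) {
        const int32_t pos = output_ids[i];
        if (pos >= 0) {
            if ((size_t) pos >= n_outputs) {
                n_outputs = pos + 1;       // highest used slot + 1 == outputs of the whole batch
            }
            output_pos[pos] = (int32_t) i; // output slot -> batch position
        }
    }

    printf("n_outputs = %zu, positions:", n_outputs);
    for (size_t j = 0; j < n_outputs; ++j) {
        printf(" %d", output_pos[j]);
    }
    printf("\n"); // prints: n_outputs = 2, positions: 2 5
    return 0;
}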
@@ -14201,7 +14201,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         // copy logits
         {
             const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);

             data_ctx->write(&logits_size, sizeof(logits_size));

             if (logits_size) {