llama : only reserve n_vocab * n_batch at most for logits

llama_decode asserts that only n_batch tokens are passed each call, and
n_ctx is expected to be bigger than n_batch.
This commit is contained in:
David Friehs 2024-01-08 08:54:13 +01:00
parent aee95df8f1
commit 0093dea953

View file

@ -9797,7 +9797,7 @@ struct llama_context * llama_new_context_with_model(
// resized during inference
if (params.logits_all) {
ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
} else {
ctx->logits.reserve(hparams.n_vocab);
}