llama : fix running a batch with n_outputs == 0

It previously appeared to work only because lctx.inp_out_ids was not
initialized, so it pointed to a garbage address which happened to still
be valid when I ran my tests.
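
To illustrate why the old behaviour was fragile, here is a minimal
standalone C++ sketch (the dummy_context/dummy_tensor types are
hypothetical, not the actual llama.cpp API) of the pattern this fix
enforces: a tensor that is only created when outputs exist must be
guarded before it is dereferenced.

    #include <cassert>
    #include <cstdint>

    struct dummy_tensor { int32_t data[512]; };

    struct dummy_context {
        dummy_tensor * inp_out_ids = nullptr; // stays unset when no outputs are needed
        int32_t        n_outputs   = 0;
    };

    static void set_output_ids(dummy_context & ctx) {
        // The guard added by this commit: only touch inp_out_ids when
        // there are outputs. Without it, a batch with n_outputs == 0
        // dereferences an uninitialized pointer, which is undefined
        // behavior even when it happens not to crash during testing.
        if (ctx.n_outputs > 0) {
            assert(ctx.inp_out_ids && "every model that can must skip unused outputs");
            for (int32_t i = 0; i < ctx.n_outputs; ++i) {
                ctx.inp_out_ids->data[i] = i; // write one id per requested output
            }
        }
    }

In the real code, the guard is the `lctx.n_outputs > 0 &&` condition
added in the diff below.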
Author: Francis Couture-Harpin
Date:   2024-03-17 20:41:21 -04:00
Parent: a57fa7faa4
Commit: 711b0bcb11


@@ -8979,7 +8979,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (lctx.n_outputs > 0 && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;