llama : fix embedding conditions

Author: Francis Couture-Harpin
Date:   2024-03-17 15:23:44 -04:00
Parent: d0129e8e29
Commit: 487f89ec2e

llama.cpp

@@ -8992,7 +8992,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     data[n_outputs++] = i;
                 }
             }
-        } else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+        } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
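
In the first hunk, the check that decides whether every token of the batch becomes an output row now reads the context-level cparams.pooling_type instead of the model-level hparams.pooling_type, presumably because the context value reflects any per-context override of the model default. Below is a minimal, self-contained sketch of the corrected condition; the struct and helper names are hypothetical stand-ins for illustration, not the library's actual types.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for llama.cpp's internal parameters (illustration only).
enum llama_pooling_type { LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2 };

struct cparams_t {
    bool               embeddings;
    llama_pooling_type pooling_type;
};

// Corrected condition from the first hunk: use the context-level pooling type (cparams),
// not the model default (hparams), to decide whether all tokens are output rows.
static bool all_tokens_are_outputs(bool logits_all, const cparams_t & cparams) {
    return logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE);
}

int main() {
    const cparams_t cparams = { /*embeddings =*/ true, /*pooling_type =*/ LLAMA_POOLING_TYPE_MEAN };
    const int n_tokens = 4;

    std::vector<int32_t> out_ids;
    if (all_tokens_are_outputs(/*logits_all =*/ false, cparams)) {
        for (int i = 0; i < n_tokens; ++i) {
            out_ids.push_back(i); // mirrors data[i] = i in llama_set_inputs
        }
    }
    return (int) out_ids.size() == n_tokens ? 0 : 1;
}
```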
@@ -9205,7 +9205,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     if (!lctx.output_ids) {
         // never resized afterwards
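
In the second hunk, the corrected has_embd condition means a context that requests embeddings on a causal model (hparams.causal_attn == true) reserves the embedding output buffer even when a pooling type is set; with the old !hparams.causal_attn test it did not. Below is a rough, self-contained sketch using simplified, hypothetical structs; the buffer-size arithmetic is an assumption about what llama_output_reserve does, not a copy of it.

```cpp
#include <cstddef>

// Hypothetical, simplified stand-ins for the real hparams/cparams (illustration only).
enum llama_pooling_type { LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1 };

struct hparams_t { bool causal_attn; };
struct cparams_t { bool causal_attn; bool embeddings; llama_pooling_type pooling_type; };

// Corrected reservation conditions from the second hunk; the size math below is a
// rough sketch of the reservation logic, not a verbatim copy of llama_output_reserve.
static void reserve_sizes(const hparams_t & hparams, const cparams_t & cparams,
                          size_t n_vocab, size_t n_embd, size_t n_outputs_max,
                          size_t & logits_size, size_t & embd_size) {
    const bool has_logits = cparams.causal_attn;
    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

    logits_size = has_logits ? n_vocab * n_outputs_max : 0;
    embd_size   = has_embd   ? n_embd  * n_outputs_max : 0;
}

int main() {
    // A causal model used for embeddings with mean pooling:
    const hparams_t hparams = { /*causal_attn =*/ true };
    const cparams_t cparams = { /*causal_attn =*/ true, /*embeddings =*/ true, LLAMA_POOLING_TYPE_MEAN };

    size_t logits_size = 0, embd_size = 0;
    reserve_sizes(hparams, cparams, /*n_vocab =*/ 32000, /*n_embd =*/ 4096, /*n_outputs_max =*/ 8,
                  logits_size, embd_size);

    // With the fix the embedding buffer is reserved here (embd_size != 0);
    // with the old (!hparams.causal_attn) condition it would have been 0.
    return (logits_size != 0 && embd_size != 0) ? 0 : 1;
}
```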