llama : fix embedding conditions

This commit is contained in:
Francis Couture-Harpin 2024-03-17 15:23:44 -04:00
parent d0129e8e29
commit 487f89ec2e

View file

@ -8992,7 +8992,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
data[n_outputs++] = i; data[n_outputs++] = i;
} }
} }
} else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
data[i] = i; data[i] = i;
} }
@ -9205,7 +9205,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
// TODO: use a per-batch flag for logits presence instead // TODO: use a per-batch flag for logits presence instead
const bool has_logits = cparams.causal_attn; const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
if (!lctx.output_ids) { if (!lctx.output_ids) {
// never resized afterwards // never resized afterwards