llama : fix embedding conditions
This commit is contained in:
parent
d0129e8e29
commit
487f89ec2e
1 changed files with 2 additions and 2 deletions
|
@ -8992,7 +8992,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
data[n_outputs++] = i;
|
||||
}
|
||||
}
|
||||
} else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
|
||||
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
@ -9205,7 +9205,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
|
|||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
const bool has_logits = cparams.causal_attn;
|
||||
const bool has_embd = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
||||
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
||||
|
||||
if (!lctx.output_ids) {
|
||||
// never resized afterwards
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue