only use embd output for pooling_type NONE
This commit is contained in:
parent 1756c4b5b6
commit 7c37ae9d29
1 changed file with 3 additions and 3 deletions
@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
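Note on the hunk above: the guard around the inp_out_ids setup ("skip unused outputs") now asks whether per-token outputs will actually be consumed, i.e. logits (embeddings not requested) or unpooled per-token embeddings (pooling_type NONE), instead of keying off hparams.causal_attn. A minimal standalone sketch of that predicate, using hypothetical struct and enum names rather than the real llama.cpp types:

// Standalone sketch (hypothetical names, not the real llama.cpp types) of the
// new guard: per-token outputs are gathered when the caller wants logits
// (embeddings off) or wants unpooled per-token embeddings (pooling NONE).
#include <cassert>

enum class pooling { none, mean, cls };

struct ctx_params {
    bool    embeddings;
    pooling pooling_type;
};

static bool needs_per_token_outputs(const ctx_params & cp) {
    return !cp.embeddings || cp.pooling_type == pooling::none;
}

int main() {
    assert( needs_per_token_outputs({ /*embeddings=*/false, pooling::mean }));  // plain logits
    assert( needs_per_token_outputs({ /*embeddings=*/true,  pooling::none }));  // per-token embeddings
    assert(!needs_per_token_outputs({ /*embeddings=*/true,  pooling::mean }));  // pooled embeddings: rows can be skipped
    return 0;
}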
@@ -11811,7 +11811,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // (!a || b) is a logical implication (a -> b)
         // !hparams.causal_attn -> !cparams.causal_attn
         (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention with embedding models is not supported"
+        "causal attention is not supported by this model"
     );
 
     if (lctx.inp_KQ_mask) {
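The assert in this hunk uses the (!a || b) encoding of logical implication spelled out in its comments: a model whose hparams say it has no causal attention must not be run with causal attention enabled in cparams. Only the message string changes. A tiny self-contained illustration of the pattern (not llama.cpp code, values are made up):

// Standalone illustration of the (!a || b) encoding of logical implication:
// (!a || b) is true exactly when a -> b holds.
#include <cassert>

static bool implies(bool a, bool b) {
    return !a || b; // a -> b
}

int main() {
    // The assert above encodes !hparams.causal_attn -> !cparams.causal_attn.
    const bool hparams_causal_attn = false; // hypothetical values, for illustration
    const bool cparams_causal_attn = false;

    // equivalent to (hparams.causal_attn || !cparams.causal_attn)
    assert(implies(!hparams_causal_attn, !cparams_causal_attn));

    // the only failing combination is a = true, b = false
    assert(!implies(/*a=*/true, /*b=*/false));
    return 0;
}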
@@ -12036,7 +12036,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = !cparams.embeddings;
-    const bool has_embd = cparams.embeddings;
+    const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
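The last hunk is the change the commit title refers to: the per-token embd output buffer is reserved only when embeddings are requested and left unpooled (pooling_type NONE); has_logits is unchanged, and both buffers are still sized from n_outputs_max. A hedged sketch of the resulting sizing logic, with made-up example dimensions (n_vocab, n_embd and n_outputs_max are illustrative only, not taken from a real model):

// Hedged sketch of the buffer sizing in llama_output_reserve after this commit.
#include <cstddef>
#include <cstdio>

int main() {
    const bool embeddings   = true;  // stands in for cparams.embeddings
    const bool pooling_none = false; // stands in for cparams.pooling_type == LLAMA_POOLING_TYPE_NONE

    const size_t n_vocab       = 32000; // example values only
    const size_t n_embd        = 4096;
    const size_t n_outputs_max = 512;

    const bool has_logits = !embeddings;
    // per-token embd buffer only when embeddings are requested AND unpooled
    const bool has_embd   = embeddings && pooling_none;

    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
    const size_t embd_size   = has_embd   ? n_embd*n_outputs_max  : 0;

    // with pooled embeddings (as configured here) no per-token embd buffer is reserved
    std::printf("logits floats: %zu, embd floats: %zu\n", logits_size, embd_size);
    return 0;
}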