diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 659c90245..b4b73c017 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -17,25 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -192,7 +177,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index 3501a0eb3..eb89d16da 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -73,25 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -175,7 +160,12 @@ int main(int argc, char ** argv) {
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
 
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -247,7 +237,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
@@ -270,7 +260,7 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-    batch_add_seq(query_batch, query_tokens, 0, pooling_type);
+    batch_add_seq(query_batch, query_tokens, 0);
 
     std::vector<float> query_emb(n_embd, 0);
     batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
diff --git a/llama.cpp b/llama.cpp
index 42be6bed9..ef48e3c24 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
@@ -12166,11 +12166,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only