From 1756c4b5b69d1bd2c36d55f235be2425ec3a138e Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Wed, 22 May 2024 22:42:08 -0500
Subject: [PATCH] find result_norm/result_embd tensors properly; update output
 allocation logic

---
 examples/embedding/embedding.cpp |  4 ++--
 examples/retrieval/retrieval.cpp |  6 +++---
 llama.cpp                        | 18 ++++++++++++------
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 9d66b5477..659c90245 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -31,8 +31,8 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index ee43430f0..3501a0eb3 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -87,9 +87,9 @@ static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tok
     }
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
-    int n_tokens = tokens.size();
-    for (size_t i = 0; i < tokens.size(); i++) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
         bool logit = needs_logit(pooling_type, i, n_tokens);
         llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
diff --git a/llama.cpp b/llama.cpp
index 2f906af7d..60d562b00 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7436,11 +7436,17 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
-        struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
-        if (strcmp(inp->name, "result_embd") != 0) {
-            inp = gf->nodes[gf->n_nodes - 2];
-            GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
         }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
 
         struct ggml_tensor * cur;
 
@@ -12029,8 +12035,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   =  cparams.embeddings;
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
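
Note (not part of the patch): the core change in append_pooling is replacing the
old assumption that the embeddings tensor is one of the last two graph nodes with
a backward scan over all nodes for a tensor named "result_norm" or "result_embd".
Below is a minimal standalone sketch of that search pattern; the Node struct and
find_embd_node helper are hypothetical stand-ins for ggml_tensor and the in-tree
code, not llama.cpp API.

    // Sketch of the backward node scan, under the assumptions above.
    #include <cstring>
    #include <cstdio>

    struct Node { const char * name; };

    // Walk the node list from the back and return the last tensor named
    // "result_norm" or "result_embd", or nullptr if neither exists.
    static Node * find_embd_node(Node ** nodes, int n_nodes) {
        for (int i = n_nodes - 1; i >= 0; --i) {
            if (strcmp(nodes[i]->name, "result_norm") == 0 ||
                strcmp(nodes[i]->name, "result_embd") == 0) {
                return nodes[i];
            }
        }
        return nullptr; // caller asserts, as the patch does with GGML_ASSERT
    }

    int main() {
        Node a = { "inp_tokens" }, b = { "result_norm" }, c = { "ffn_out" };
        Node * nodes[] = { &a, &b, &c };
        Node * inp = find_embd_node(nodes, 3);
        printf("found: %s\n", inp ? inp->name : "(none)"); // found: result_norm
        return 0;
    }

On the allocation side, the patch makes logits and embeddings buffers mutually
exclusive in llama_output_reserve: a context created with cparams.embeddings set
reserves only the embd buffer, and otherwise only the logits buffer.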