From e379e8c10b11f20ca6a2fdeb19b45a05a9247128 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Sun, 11 Feb 2024 09:50:18 -0600 Subject: [PATCH] avoid use of ggml_graph_get_tensor --- llama.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index f64571dc8..51c2264db 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7484,15 +7484,21 @@ static int llama_decode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch); - // get logits and embeddings - struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output"); - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm"); - - // if logits are none we must be doing embeddings - if (res == nullptr) { - embeddings = ggml_graph_get_tensor(gf, "result_embed"); + // the output is always the last tensor in the graph + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + if (strcmp(res->name, "result_output") == 0) { + // the embeddings could be the second to last tensor, or the third to last tensor + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } + } else if (strcmp(res->name, "result_embed") == 0) { + embeddings = res; + res = nullptr; + } else { + GGML_ASSERT(false); } - GGML_ASSERT(res || embeddings); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);