avoid use of ggml_graph_get_tensor

This commit is contained in:
Douglas Hanley 2024-02-11 09:50:18 -06:00
parent 6972e7e90e
commit e379e8c10b

View file

@ -7484,15 +7484,21 @@ static int llama_decode_internal(
ggml_cgraph * gf = llama_build_graph(lctx, batch); ggml_cgraph * gf = llama_build_graph(lctx, batch);
// get logits and embeddings // the output is always the last tensor in the graph
struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output"); struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm"); struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
if (strcmp(res->name, "result_output") == 0) {
// if logits are none we must be doing embeddings // the embeddings could be the second to last tensor, or the third to last tensor
if (res == nullptr) { if (strcmp(embeddings->name, "result_norm") != 0) {
embeddings = ggml_graph_get_tensor(gf, "result_embed"); embeddings = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
}
} else if (strcmp(res->name, "result_embed") == 0) {
embeddings = res;
res = nullptr;
} else {
GGML_ASSERT(false);
} }
GGML_ASSERT(res || embeddings);
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);