avoid use of ggml_graph_get_tensor
This commit is contained in:
parent
6972e7e90e
commit
e379e8c10b
1 changed file with 14 additions and 8 deletions
22
llama.cpp
22
llama.cpp
|
@@ -7484,15 +7484,21 @@ static int llama_decode_internal(
|
|||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, batch);
|
||||
|
||||
// get logits and embeddings
|
||||
struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output");
|
||||
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm");
|
||||
|
||||
// if logits are none we must be doing embeddings
|
||||
if (res == nullptr) {
|
||||
embeddings = ggml_graph_get_tensor(gf, "result_embed");
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||
if (strcmp(res->name, "result_output") == 0) {
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
}
|
||||
} else if (strcmp(res->name, "result_embed") == 0) {
|
||||
embeddings = res;
|
||||
res = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
GGML_ASSERT(res || embeddings);
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue