the cur parameter is missing

This commit is contained in:
gklab 2023-07-31 17:41:34 +08:00
parent a113689571
commit 4d92be8813

View file

@ -1386,7 +1386,7 @@ static bool llama_model_load(
}
}
static struct ggml_cgraph * llama_build_graph(
static struct std::pair<ggml_cgraph *, ggml_tensor *> llama_build_graph(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
@ -1755,7 +1755,7 @@ static struct ggml_cgraph * llama_build_graph(
ggml_free(ctx0);
return gf;
return std::make_pair(gf, cur);
}
// evaluate the transformer
@ -1799,8 +1799,9 @@ static bool llama_eval_internal(
#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_reset(lctx.alloc);
#endif
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
auto result = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
ggml_cgraph * gf = result.first;
ggml_tensor * cur = result.second;
#ifdef LLAMA_USE_ALLOCATOR
ggml_allocr_alloc_graph(lctx.alloc, gf);
@ -3279,7 +3280,9 @@ struct llama_context * llama_new_context_with_model(
int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
int n_past = hparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
auto result = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
ggml_cgraph * gf = result.first;
ggml_tensor * cur = result.second;
// measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;