llama : dedup reserve code
This commit is contained in:
parent
972f91c7d7
commit
f9971ef2e1
1 changed files with 2 additions and 48 deletions
|
@ -7629,30 +7629,6 @@ static int llama_decode_impl(
|
|||
return -3;
|
||||
}
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
// TODO: extract to a function
|
||||
if (lctx.need_reserve) {
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
|
||||
lctx.need_reserve = false;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
|
||||
|
@ -7889,30 +7865,8 @@ static int llama_encode_impl(
|
|||
|
||||
//batch_manager->prepare(ubatch);
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
// TODO: extract to a function
|
||||
if (lctx.need_reserve) {
|
||||
// TODO: extract to a function
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
|
||||
lctx.need_reserve = false;
|
||||
}
|
||||
// TODO: do reserve
|
||||
GGML_ASSERT(lctx.need_reserve == false);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue