llama : dedup reserve code
parent 972f91c7d7
commit f9971ef2e1

1 changed file with 2 additions and 48 deletions
@@ -7629,30 +7629,6 @@ static int llama_decode_impl(
             return -3;
         }
 
-        // reserve a worst case graph if needed
-        // TODO: extract to a function
-        if (lctx.need_reserve) {
-            const auto & cparams = lctx.cparams;
-            const auto & model   = lctx.model;
-
-            // build worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-            // initialize scheduler with the worst-case graph
-            ggml_backend_sched_reset(lctx.sched.get());
-            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-            }
-
-            lctx.need_reserve = false;
-        }
-
         ggml_backend_sched_reset(lctx.sched.get());
         ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
@@ -7889,30 +7865,8 @@ static int llama_encode_impl(
 
     //batch_manager->prepare(ubatch);
 
-    // reserve a worst case graph if needed
-    // TODO: extract to a function
-    if (lctx.need_reserve) {
-        // TODO: extract to a function
-        const auto & cparams = lctx.cparams;
-        const auto & model   = lctx.model;
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-
-        lctx.need_reserve = false;
-    }
+    // TODO: do reserve
+    GGML_ASSERT(lctx.need_reserve == false);
 
     ggml_backend_sched_reset(lctx.sched.get());
     ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
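Note: the block deleted from llama_decode_impl and llama_encode_impl above is the same worst-case graph reservation. As a sketch of how that duplicated block could be factored into one shared helper, the code below reuses only the removed lines and the internal types and functions visible in them (llama_context, llama_ubatch, llama_build_graph, ggml_backend_sched_reserve); the helper name llama_reserve_worst_case_graph is hypothetical and is not added by this commit, which only deletes the copies.

// Hypothetical helper (name is an assumption, not part of this commit):
// factors out the worst-case graph reservation that was duplicated in the
// decode and encode paths. The body is taken from the removed lines above.
static void llama_reserve_worst_case_graph(llama_context & lctx) {
    const auto & cparams = lctx.cparams;
    const auto & model   = lctx.model;

    // build worst-case graph: a single sequence filling one ubatch
    uint32_t n_seqs = 1; // TODO: worst-case number of sequences
    uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    // the token value is not actually used by llama_build_graph, but it is
    // required to choose between the token-input and embedding-input graphs
    llama_token token = model.vocab.token_bos();
    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

    // initialize the scheduler with the worst-case graph so the compute
    // buffers are sized once for the largest possible batch
    ggml_backend_sched_reset(lctx.sched.get());
    if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
    }

    lctx.need_reserve = false;
}

With such a helper, each call site would reduce to a guarded call of the form: if (lctx.need_reserve) { llama_reserve_worst_case_graph(lctx); }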