llama : dedup reserve code

This commit is contained in:
Georgi Gerganov 2025-02-10 14:59:51 +02:00
parent 972f91c7d7
commit f9971ef2e1
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -7629,30 +7629,6 @@ static int llama_decode_impl(
return -3;
}
// reserve a worst case graph if needed
// TODO: extract to a function
if (lctx.need_reserve) {
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
ggml_backend_sched_reset(lctx.sched.get());
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -7889,30 +7865,8 @@ static int llama_encode_impl(
//batch_manager->prepare(ubatch);
// reserve a worst case graph if needed
// TODO: extract to a function
if (lctx.need_reserve) {
// TODO: extract to a function
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
// TODO: do reserve
GGML_ASSERT(lctx.need_reserve == false);
ggml_backend_sched_reset(lctx.sched.get());
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);