context : store graph build function callback

ggml-ci
Georgi Gerganov 2025-02-02 10:17:42 +02:00
parent 5d3491e789
commit 3e23be7911
3 changed files with 41 additions and 8 deletions


@@ -33,8 +33,12 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
return relative_bucket;
}
llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
llama_context::llama_context(
const llama_model & model,
const llama_context_params & params,
build_graph_callback && cb_build_graph) :
model(model),
cb_build_graph(std::move(cb_build_graph)),
t_start_us(model.t_start_us),
t_load_us (model.t_load_us) {
@@ -289,7 +293,7 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
// reserve pp graph first so that buffers are only allocated once
ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -298,13 +302,13 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
// reserve with tg graph to get the number of splits and nodes
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
ggml_backend_sched_reserve(sched.get(), gf_tg);
int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
gf_pp = fn_build_graph_worst(*this, ubatch_pp);
gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
throw std::runtime_error("failed to allocate compute buffers");
@@ -475,6 +479,31 @@ struct llama_batch_manager : public llama_batch_manager_i {
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (lctx.need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
return true;
}
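
For orientation on the hunk above: when `need_reserve` is set, the context rebuilds a worst-case graph through the stored `cb_build_graph` callback and re-reserves the scheduler buffers. A minimal standalone sketch of that store-the-callback-and-rebuild-on-demand pattern, using hypothetical `Graph`/`Batch`/`Context` stand-ins rather than the real ggml/llama types:

```cpp
#include <cstdio>
#include <functional>
#include <utility>

// Hypothetical stand-ins for ggml_cgraph and llama_ubatch.
struct Graph { int n_nodes; };
struct Batch { int n_tokens; };

struct Context {
    // Same shape as the new build_graph_callback: (context, batch, worst_case).
    using build_graph_callback = std::function<Graph (Context &, const Batch &, bool)>;

    build_graph_callback cb_build_graph;
    bool need_reserve = true; // set whenever compute-buffer sizes may have changed

    explicit Context(build_graph_callback && cb) : cb_build_graph(std::move(cb)) {}

    void process(const Batch & batch) {
        if (need_reserve) {
            // Rebuild a worst-case graph through the stored callback; the real
            // code then hands it to ggml_backend_sched_reserve().
            const Batch worst = { /*n_tokens =*/ 512 };
            const Graph gf = cb_build_graph(*this, worst, /*worst_case =*/ true);
            std::printf("reserved for a worst-case graph with %d nodes\n", gf.n_nodes);
            need_reserve = false;
        }
        // ... build and run the actual graph for `batch` here ...
        (void) batch;
    }
};

int main() {
    Context ctx([](Context &, const Batch & b, bool worst_case) {
        // Pretend the graph size scales with the batch; worst_case picks an upper bound.
        return Graph{ worst_case ? 2 * b.n_tokens : b.n_tokens };
    });
    ctx.process(Batch{ 32 });
    return 0;
}
```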


@@ -36,11 +36,13 @@ struct llama_batch_manager_i {
// TODO: make implementation details private
// TODO: become abstract base class, split the current implementation into different child classes
struct llama_context {
// TODO: store the worst-case graph build function and reuse it later
// TODO: tmp until llama-model starts implementing the graph build function
typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
llama_context(
const llama_model & model,
const llama_context_params & params,
std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst);
build_graph_callback && cb_build_graph);
const struct llama_model & model;
@@ -49,6 +51,8 @@ struct llama_context {
llama_adapter_cvec cvec;
llama_loras loras;
build_graph_callback cb_build_graph;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
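
The header hunk above also shows the constructor taking the callback as `build_graph_callback &&` and moving it into the `cb_build_graph` member. A tiny self-contained sketch of that initialization pattern, with placeholder names (`Widget`, `callback`) rather than the actual llama.cpp types:

```cpp
#include <functional>
#include <utility>

struct Widget {
    using callback = std::function<int (int)>;

    callback cb;

    // Taking the std::function by rvalue reference and moving it into the member
    // means the callable (and whatever its lambda captures) is moved, not copied.
    explicit Widget(callback && cb) : cb(std::move(cb)) {}
};

int main() {
    Widget w([](int x) { return x + 1; });
    return w.cb(41) == 42 ? 0 : 1;
}
```

Taking the parameter by value and moving it would behave much the same and would also accept lvalues; the `&&` form just makes the transfer of ownership explicit at the call site.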


@@ -8508,8 +8508,8 @@ struct llama_context * llama_init_from_model(
try {
// TODO: add logic which llama_context implementation to construct
ctx = new llama_context(*model, params,
[](llama_context & lctx, const llama_ubatch & ubatch) {
return llama_build_graph(lctx, ubatch, true);
[](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) {
return llama_build_graph(lctx, ubatch, worst_case);
});
} catch (const std::exception & e) {
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
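
The adapter lambda in the hunk above now forwards the `worst_case` flag to `llama_build_graph` instead of hard-coding `true`, which is presumably what lets later call sites request regular (non-worst-case) graphs through the same stored callback. A self-contained sketch of that adapter idea, with hypothetical `Ctx`/`build_graph` stand-ins rather than the real API; the `worst_case = false` call is an assumption about later usage, not something shown in this commit:

```cpp
#include <cstdio>
#include <functional>

struct Ctx; // forward declaration so the callback signature can refer to it

struct Graph { int n_nodes; };

using build_graph_callback = std::function<Graph (Ctx &, int n_tokens, bool worst_case)>;

struct Ctx {
    build_graph_callback cb_build_graph;
};

// Hypothetical stand-in for llama_build_graph(lctx, ubatch, worst_case).
static Graph build_graph(Ctx & /*ctx*/, int n_tokens, bool worst_case) {
    // A worst-case build sizes the graph for the upper bound.
    return Graph{ worst_case ? 4 * n_tokens : n_tokens };
}

int main() {
    // Mirrors the lambda passed at context construction: forward all arguments,
    // including worst_case, to the underlying build function.
    Ctx ctx{ [](Ctx & c, int n_tokens, bool worst_case) {
        return build_graph(c, n_tokens, worst_case);
    } };

    const Graph reserve = ctx.cb_build_graph(ctx, 512, /*worst_case =*/ true);  // reservation path
    const Graph decode  = ctx.cb_build_graph(ctx, 32,  /*worst_case =*/ false); // assumed later, regular build
    std::printf("reserve: %d nodes, decode: %d nodes\n", reserve.n_nodes, decode.n_nodes);
    return 0;
}
```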