From 3e23be7911704f8474e7dcb32424bb043be63b06 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 2 Feb 2025 10:17:42 +0200
Subject: [PATCH] context : store graph build function callback

ggml-ci
---
 src/llama-context.cpp | 37 +++++++++++++++++++++++++++++++++----
 src/llama-context.h   |  8 ++++++--
 src/llama.cpp         |  4 ++--
 3 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 37e43213a..1cd168db2 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -33,8 +33,12 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
+llama_context::llama_context(
+        const llama_model & model,
+        const llama_context_params & params,
+        build_graph_callback && cb_build_graph) :
     model(model),
+    cb_build_graph(std::move(cb_build_graph)),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
 
@@ -289,7 +293,7 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
             llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
-            ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+            ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
 
             // reserve pp graph first so that buffers are only allocated once
             ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -298,13 +302,13 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 
             // reserve with tg graph to get the number of splits and nodes
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
+            ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
             ggml_backend_sched_reserve(sched.get(), gf_tg);
             int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             int n_nodes_tg  = ggml_graph_n_nodes(gf_tg);
 
             // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-            gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+            gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
             if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
                 throw std::runtime_error("failed to allocate compute buffers");
@@ -475,6 +479,31 @@ struct llama_batch_manager : public llama_batch_manager_i {
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+        // reserve a worst case graph if needed
+        if (lctx.need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            const auto & cparams = lctx.cparams;
+            const auto & model   = lctx.model;
+
+            // build worst-case graph
+            uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(lctx.sched.get());
+            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            lctx.need_reserve = false;
+        }
+
         return true;
     }
 
diff --git a/src/llama-context.h b/src/llama-context.h
index 1277645de..5958deaef 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -36,11 +36,13 @@ struct llama_batch_manager_i {
 // TODO: make implementation details private
 // TODO: become abstract base class, split the current implementation into different child classes
 struct llama_context {
-    // TODO: store the worst-case graph build function and reuse it later
+    // TODO: tmp until llama-model starts implementing the graph build function
+    typedef std::function<ggml_cgraph *(llama_context & lctx, const llama_ubatch & ubatch, bool worst_case)> build_graph_callback;
+
     llama_context(
             const llama_model & model,
             const llama_context_params & params,
-            std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst);
+            build_graph_callback && cb_build_graph);
 
     const struct llama_model & model;
 
@@ -49,6 +51,8 @@ struct llama_context {
     llama_adapter_cvec cvec;
     llama_loras        loras;
 
+    build_graph_callback cb_build_graph;
+
     std::vector<ggml_backend_ptr> backends;
     std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 0ca8070cd..6268249f2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8508,8 +8508,8 @@ struct llama_context * llama_init_from_model(
     try {
         // TODO: add logic which llama_context implementation to construct
        ctx = new llama_context(*model, params,
-                [](llama_context & lctx, const llama_ubatch & ubatch) {
-                    return llama_build_graph(lctx, ubatch, true);
+                [](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) {
+                    return llama_build_graph(lctx, ubatch, worst_case);
                 });
     } catch (const std::exception & e) {
         LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
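
Note (not part of the patch): the change replaces a fire-once "worst-case graph" builder argument with a callback stored as a member, which is what allows the new need_reserve branch in llama_batch_manager to re-reserve the scheduler long after construction. The stand-alone sketch below illustrates just this ownership pattern; the types graph, ubatch_t and context are hypothetical stand-ins for ggml_cgraph, llama_ubatch and llama_context, not the real llama.cpp API:

    // Minimal sketch of the stored graph-build callback pattern used above.
    // graph, ubatch_t and context are hypothetical stand-ins -- this is not
    // the llama.cpp API.
    #include <cstdio>
    #include <functional>
    #include <utility>

    struct graph    { int n_nodes;  };
    struct ubatch_t { int n_tokens; };

    struct context {
        // same shape as llama_context::build_graph_callback in the patch
        typedef std::function<graph (context &, const ubatch_t &, bool)> build_graph_callback;

        // the callback is moved into a member instead of being used once
        context(build_graph_callback && cb) : cb_build_graph(std::move(cb)) {
            // initial worst-case reservation, as in the constructor above
            const graph gf = cb_build_graph(*this, ubatch_t { 512 }, true);
            std::printf("reserved %d nodes\n", gf.n_nodes);
        }

        // analogous to the need_reserve branch added to llama_batch_manager:
        // deferred re-reservation is possible only because the callback is kept
        void maybe_reserve() {
            if (!need_reserve) {
                return;
            }
            const graph gf = cb_build_graph(*this, ubatch_t { 512 }, true);
            std::printf("re-reserved %d nodes\n", gf.n_nodes);
            need_reserve = false;
        }

        build_graph_callback cb_build_graph;
        bool need_reserve = true;
    };

    int main() {
        // the lambda plays the role of llama_build_graph(lctx, ubatch, worst_case)
        context ctx([](context &, const ubatch_t & ub, bool worst_case) {
            return graph { worst_case ? 2 * ub.n_tokens : ub.n_tokens };
        });
        ctx.maybe_reserve();
        return 0;
    }

Taking the callback by rvalue reference and moving it into the member avoids copying the std::function state, and the added bool worst_case parameter lets one stored callback cover both the worst-case reservations shown here and, presumably, regular batch graphs later.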