context : store graph build function callback
ggml-ci
This commit is contained in:
parent
5d3491e789
commit
3e23be7911
3 changed files with 41 additions and 8 deletions
|
@ -33,8 +33,12 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
|
|||
return relative_bucket;
|
||||
}
|
||||
|
||||
llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params,
|
||||
build_graph_callback && cb_build_graph) :
|
||||
model(model),
|
||||
cb_build_graph(std::move(cb_build_graph)),
|
||||
t_start_us(model.t_start_us),
|
||||
t_load_us (model.t_load_us) {
|
||||
|
||||
|
@ -289,7 +293,7 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
|
|||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
|
||||
ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
ggml_backend_sched_reserve(sched.get(), gf_pp);
|
||||
|
@ -298,13 +302,13 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
|
|||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
|
||||
ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
|
||||
ggml_backend_sched_reserve(sched.get(), gf_tg);
|
||||
int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
||||
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
gf_pp = fn_build_graph_worst(*this, ubatch_pp);
|
||||
gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
throw std::runtime_error("failed to allocate compute buffers");
|
||||
|
@ -475,6 +479,31 @@ struct llama_batch_manager : public llama_batch_manager_i {
|
|||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
if (lctx.need_reserve) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
|
||||
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
|
||||
lctx.need_reserve = false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,11 +36,13 @@ struct llama_batch_manager_i {
|
|||
// TODO: make implementation details private
|
||||
// TODO: become abstract base class, split the current implementation into different child classes
|
||||
struct llama_context {
|
||||
// TODO: store the worst-case graph build function and reuse it later
|
||||
// TODO: tmp until llama-model starts implementing the graph build function
|
||||
typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
|
||||
|
||||
llama_context(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params,
|
||||
std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst);
|
||||
build_graph_callback && cb_build_graph);
|
||||
|
||||
const struct llama_model & model;
|
||||
|
||||
|
@ -49,6 +51,8 @@ struct llama_context {
|
|||
llama_adapter_cvec cvec;
|
||||
llama_loras loras;
|
||||
|
||||
build_graph_callback cb_build_graph;
|
||||
|
||||
std::vector<ggml_backend_ptr> backends;
|
||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||
|
||||
|
|
|
@ -8508,8 +8508,8 @@ struct llama_context * llama_init_from_model(
|
|||
try {
|
||||
// TODO: add logic which llama_context implementation to construct
|
||||
ctx = new llama_context(*model, params,
|
||||
[](llama_context & lctx, const llama_ubatch & ubatch) {
|
||||
return llama_build_graph(lctx, ubatch, true);
|
||||
[](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) {
|
||||
return llama_build_graph(lctx, ubatch, worst_case);
|
||||
});
|
||||
} catch (const std::exception & e) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue