diff --git a/common/common.cpp b/common/common.cpp
index 680f06990..ccc253ca3 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1923,6 +1923,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
 
diff --git a/common/common.h b/common/common.h
index 62b7b05e3..f17fd8075 100644
--- a/common/common.h
+++ b/common/common.h
@@ -88,6 +88,7 @@ struct gpt_params {
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
     ggml_backend_sched_split_done_callback cb_split_done = nullptr;
+    void * cb_split_done_user_data = nullptr;
 
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index 6ac2d9f09..ff65df739 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -38,15 +38,13 @@ struct speculation_context
     bool done;
 };
 
-speculation_context spec_ctx;
-
-// pass void * spec_ctx
-static void split_done_cb(int split)
+static void split_done_cb(int split, void * p_spec_ctx)
 {
     if (split == 1 || split == 2)
     {
-        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
-        spec_ctx.vacant_id = split - 1;
+        auto * spec_ctx = static_cast<speculation_context *>(p_spec_ctx);
+        std::lock_guard<std::mutex> guard(spec_ctx->mtx);
+        spec_ctx->vacant_id = split - 1;
     }
 }
 
@@ -170,7 +168,8 @@ static int speculation(
 }
 
 static int target(
-    llama_model * model,
+    llama_model * model,
+    speculation_context * spec_ctx,
     llama_context * ctx,
     const llama_tokens& input,
     size_t n_predict)
@@ -238,8 +237,8 @@ static int target(
     }
 
     {
-        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-        auto & spec = spec_ctx.candidate;
+        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+        auto & spec = spec_ctx->candidate;
         size_t n_match = 0;
         for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
         {
@@ -285,8 +284,8 @@ static int target(
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
     {
-        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-        spec_ctx.done = true;
+        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+        spec_ctx->done = true;
     }
 
     llama_batch_free(batch);
@@ -306,11 +305,13 @@ int main(int argc, char ** argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
 
+    speculation_context spec_ctx;
     // main model and context
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     params.cb_split_done = split_done_cb;
+    params.cb_split_done_user_data = &spec_ctx;
 
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     llama_tokens input = llama_tokenize(ctx, params.prompt, true);
@@ -333,7 +334,7 @@ int main(int argc, char ** argv) {
     std::tie(draft_model, draft_ctx) = llama_init_from_gpt_params(params);
     std::thread spec_thread = std::thread(speculation, draft_model, &spec_ctx, draft_ctx, input);
 
-    target(model, ctx, input, params.n_predict);
+    target(model, &spec_ctx, ctx, input, params.n_predict);
 
     spec_thread.join();
 
@@ -346,4 +347,4 @@ int main(int argc, char ** argv) {
     llama_backend_free();
 
     return 0;
-}
+}
\ No newline at end of file
diff --git a/ggml-backend.c b/ggml-backend.c
index bb042932a..77b2b43c5 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -1076,6 +1076,7 @@ struct ggml_backend_sched {
     void * callback_eval_user_data;
 
     ggml_backend_sched_split_done_callback callback_split_done;
+    void * callback_split_done_user_data;
 
     // align context_buffer to GGML_MEM_ALIGN
 #ifdef _MSC_VER
@@ -1713,7 +1714,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
 
         // split finished
        if (sched->callback_split_done) {
-            sched->callback_split_done(i);
+            sched->callback_split_done(i, sched->callback_split_done_user_data);
         }
     }
 
@@ -1863,8 +1864,9 @@ void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backe
     sched->callback_eval_user_data = user_data;
 }
 
-void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback) {
+void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data) {
     sched->callback_split_done = callback;
+    sched->callback_split_done_user_data = user_data;
 }
 
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
diff --git a/ggml-backend.h b/ggml-backend.h
index ff6d3967c..b9f5b70f5 100644
--- a/ggml-backend.h
+++ b/ggml-backend.h
@@ -177,7 +177,7 @@ extern "C" {
 
     // if set will be called when a split is completed computation
     // is useful for distributed task orchestraction
-    typedef void (*ggml_backend_sched_split_done_callback)(int split);
+    typedef void (*ggml_backend_sched_split_done_callback)(int split, void * user_data);
 
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
@@ -208,7 +208,7 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
 
     // Set a callback to be called for each resulting node during graph compute
-    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback);
+    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data);
 
     //
     // Utils
diff --git a/llama.cpp b/llama.cpp
index 2121f86ff..91a142c38 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1863,6 +1863,7 @@ struct llama_cparams {
     void * cb_eval_user_data;
 
     ggml_backend_sched_split_done_callback cb_split_done;
+    void * cb_split_done_user_data;
 };
 
 struct llama_layer {
@@ -11256,7 +11257,7 @@ static int llama_decode_internal(
 
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-        ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done);
+        ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done, lctx.cparams.cb_split_done_user_data);
 
         ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
 
@@ -15196,6 +15197,7 @@ struct llama_context_params llama_context_default_params() {
         /*.cb_eval =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
         /*.cb_split_done =*/ nullptr,
+        /*.cb_split_done_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
         /*.logits_all =*/ false,
@@ -15408,6 +15410,8 @@ struct llama_context * llama_new_context_with_model(
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
diff --git a/llama.h b/llama.h
index ab6f07d2a..a485fff15 100644
--- a/llama.h
+++ b/llama.h
@@ -290,6 +290,7 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
         ggml_backend_sched_split_done_callback cb_split_done;
+        void * cb_split_done_user_data;
 
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache
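
For context, a minimal sketch of how downstream code could use the new plumbing after this change, routing caller-owned state through cb_split_done_user_data instead of a global. The split_progress struct and on_split_done callback below are illustrative, not part of the patch; only gpt_params.cb_split_done, gpt_params.cb_split_done_user_data and the two-argument callback signature come from the diff.

#include <mutex>
#include <tuple>

#include "common.h"
#include "llama.h"

// Illustrative caller-owned state; duo.cpp uses speculation_context for the same purpose.
struct split_progress {
    std::mutex mtx;
    int last_split = -1;
};

// Matches the updated callback signature: void (*)(int split, void * user_data).
static void on_split_done(int split, void * user_data) {
    auto * progress = static_cast<split_progress *>(user_data);
    std::lock_guard<std::mutex> guard(progress->mtx);
    progress->last_split = split;
}

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    llama_backend_init();

    split_progress progress;
    params.cb_split_done           = on_split_done;
    params.cb_split_done_user_data = &progress;  // forwarded into llama_context_params, then into the sched

    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    // ... run decoding as usual; on_split_done now receives &progress for every completed split ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

This mirrors the existing cb_eval / cb_eval_user_data pair: every callback registered through gpt_params carries its own context pointer, which is what lets the global speculation_context in duo.cpp become a local in main().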