pass user data

Oleksandr Kuvshynov 2024-05-25 22:10:19 -04:00
parent 534093878b
commit 7c8699add6
7 changed files with 28 additions and 18 deletions
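
In short: before this change the split-done callback could not carry any caller state, so the speculation example kept its speculation_context in a file-scope global. The commit threads an opaque void * user_data pointer alongside the callback through every layer (gpt_params -> llama_context_params -> llama_cparams -> ggml_backend_sched), and the scheduler forwards it on each invocation. The signature change, restated from the ggml-backend.h hunk further down:

    // before: only the split index is delivered
    typedef void (*ggml_backend_sched_split_done_callback)(int split);

    // after: an opaque pointer supplied at registration time is passed along
    typedef void (*ggml_backend_sched_split_done_callback)(int split, void * user_data);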

View file

@@ -1923,6 +1923,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;

View file

@@ -88,6 +88,7 @@ struct gpt_params {
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
     ggml_backend_sched_split_done_callback cb_split_done = nullptr;
+    void * cb_split_done_user_data = nullptr;
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

View file

@@ -38,15 +38,13 @@ struct speculation_context
     bool done;
 };
 
-speculation_context spec_ctx;
-
-// pass void * spec_ctx
-static void split_done_cb(int split)
+static void split_done_cb(int split, void * p_spec_ctx)
 {
     if (split == 1 || split == 2)
     {
-        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
-        spec_ctx.vacant_id = split - 1;
+        auto * spec_ctx = static_cast<speculation_context*>(p_spec_ctx);
+        std::lock_guard<std::mutex> guard(spec_ctx->mtx);
+        spec_ctx->vacant_id = split - 1;
     }
 }
@@ -170,7 +168,8 @@ static int speculation(
 }
 
 static int target(
     llama_model * model,
+    speculation_context * spec_ctx,
     llama_context * ctx,
     const llama_tokens& input,
     size_t n_predict)
@@ -238,8 +237,8 @@ static int target(
     }
     {
-        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-        auto & spec = spec_ctx.candidate;
+        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+        auto & spec = spec_ctx->candidate;
         size_t n_match = 0;
         for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
         {
@ -285,8 +284,8 @@ static int target(
llama_print_timings(ctx); llama_print_timings(ctx);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
{ {
std::lock_guard<std::mutex> _lock(spec_ctx.mtx); std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
spec_ctx.done = true; spec_ctx->done = true;
} }
llama_batch_free(batch); llama_batch_free(batch);
@@ -306,11 +305,13 @@ int main(int argc, char ** argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
 
+    speculation_context spec_ctx;
+
     // main model and context
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     params.cb_split_done = split_done_cb;
+    params.cb_split_done_user_data = &spec_ctx;
 
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     llama_tokens input = llama_tokenize(ctx, params.prompt, true);
@@ -333,7 +334,7 @@ int main(int argc, char ** argv) {
     std::tie(draft_model, draft_ctx) = llama_init_from_gpt_params(params);
 
     std::thread spec_thread = std::thread(speculation, draft_model, &spec_ctx, draft_ctx, input);
 
-    target(model, ctx, input, params.n_predict);
+    target(model, &spec_ctx, ctx, input, params.n_predict);
 
     spec_thread.join();
@@ -346,4 +347,4 @@ int main(int argc, char ** argv) {
     llama_backend_free();
 
     return 0;
 }
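
Taken together, the example now uses the standard void*-context idiom: the code that fires the callback only holds a (callback, user_data) pair, and the callback casts the pointer back to its real type. A self-contained sketch of the same pattern (all names below are illustrative stand-ins; only the (int, void *) signature mirrors the commit):

    #include <cstdio>
    #include <mutex>

    struct context_state {
        std::mutex mtx;
        int vacant_id = -1;
    };

    // Matches the new callback shape: split index plus an opaque pointer.
    static void split_done_cb(int split, void * p_ctx) {
        auto * ctx = static_cast<context_state *>(p_ctx);
        std::lock_guard<std::mutex> guard(ctx->mtx);
        ctx->vacant_id = split - 1;
    }

    // Stand-in for the scheduler: it stores the pair and invokes it later.
    struct scheduler {
        void (*cb)(int, void *) = nullptr;
        void * cb_user_data     = nullptr;
        void run_split(int split) { if (cb) cb(split, cb_user_data); }
    };

    int main() {
        context_state state;
        scheduler sched;
        sched.cb           = split_done_cb;
        sched.cb_user_data = &state;   // no file-scope global needed any more
        sched.run_split(2);
        std::printf("vacant_id = %d\n", state.vacant_id);   // prints 1
        return 0;
    }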

View file

@@ -1076,6 +1076,7 @@ struct ggml_backend_sched {
     void * callback_eval_user_data;
 
     ggml_backend_sched_split_done_callback callback_split_done;
+    void * callback_split_done_user_data;
 
     // align context_buffer to GGML_MEM_ALIGN
 #ifdef _MSC_VER
@@ -1713,7 +1714,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         // split finished
         if (sched->callback_split_done) {
-            sched->callback_split_done(i);
+            sched->callback_split_done(i, sched->callback_split_done_user_data);
         }
     }
@@ -1863,8 +1864,9 @@ void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backe
    sched->callback_eval_user_data = user_data;
 }
 
-void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback) {
+void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data) {
     sched->callback_split_done = callback;
+    sched->callback_split_done_user_data = user_data;
 }
 
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {

View file

@@ -177,7 +177,7 @@ extern "C" {
     // if set will be called when a split is completed computation
     // is useful for distributed task orchestraction
-    typedef void (*ggml_backend_sched_split_done_callback)(int split);
+    typedef void (*ggml_backend_sched_split_done_callback)(int split, void * user_data);
 
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
@@ -208,7 +208,7 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
 
     // Set a callback to be called for each resulting node during graph compute
-    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback);
+    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data);
 
     //
     // Utils
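
A usage sketch against the updated header (it assumes a scheduler already created elsewhere with ggml_backend_sched_new; my_state, on_split_done and register_split_cb are illustrative names, not part of ggml):

    #include "ggml-backend.h"
    #include <cstdio>

    struct my_state {
        int last_split = -1;
    };

    // Signature matches the new ggml_backend_sched_split_done_callback typedef.
    static void on_split_done(int split, void * user_data) {
        auto * st = static_cast<my_state *>(user_data);
        st->last_split = split;
        std::fprintf(stderr, "split %d finished\n", split);
    }

    static void register_split_cb(ggml_backend_sched_t sched, my_state * st) {
        // the setter now takes the extra user_data argument
        ggml_backend_sched_set_split_done_callback(sched, on_split_done, st);
    }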

View file

@@ -1863,6 +1863,7 @@ struct llama_cparams {
     void * cb_eval_user_data;
     ggml_backend_sched_split_done_callback cb_split_done;
+    void * cb_split_done_user_data;
 };
 
 struct llama_layer {
@@ -11256,7 +11257,7 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-        ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done);
+        ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done, lctx.cparams.cb_split_done_user_data);
 
         ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
@@ -15196,6 +15197,7 @@ struct llama_context_params llama_context_default_params() {
         /*.cb_eval =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
         /*.cb_split_done =*/ nullptr,
+        /*.cb_split_done_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
         /*.logits_all =*/ false,
@@ -15408,6 +15410,8 @@ struct llama_context * llama_new_context_with_model(
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;
 
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {

View file

@@ -290,6 +290,7 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
         ggml_backend_sched_split_done_callback cb_split_done;
+        void * cb_split_done_user_data;
 
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache
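
And the equivalent wiring for code that builds a context through llama.h directly instead of through gpt_params (a minimal sketch; my_split_done and make_ctx are illustrative names, error handling omitted):

    #include "llama.h"

    static void my_split_done(int split, void * user_data) {
        auto * n_splits_done = static_cast<int *>(user_data);
        ++(*n_splits_done);
        (void) split;
    }

    static llama_context * make_ctx(llama_model * model, int * n_splits_done) {
        llama_context_params cparams = llama_context_default_params();
        cparams.cb_split_done           = my_split_done;
        cparams.cb_split_done_user_data = n_splits_done;   // the new field from this commit
        return llama_new_context_with_model(model, cparams);
    }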