pass user data

Oleksandr Kuvshynov 2024-05-25 22:10:19 -04:00
parent 534093878b
commit 7c8699add6
7 changed files with 28 additions and 18 deletions

View file

@@ -1923,6 +1923,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 cparams.cb_eval = params.cb_eval;
 cparams.cb_eval_user_data = params.cb_eval_user_data;
 cparams.cb_split_done = params.cb_split_done;
+cparams.cb_split_done_user_data = params.cb_split_done_user_data;
 cparams.offload_kqv = !params.no_kv_offload;
 cparams.flash_attn = params.flash_attn;

View file

@@ -88,6 +88,7 @@ struct gpt_params {
 ggml_backend_sched_eval_callback cb_eval = nullptr;
 void * cb_eval_user_data = nullptr;
 ggml_backend_sched_split_done_callback cb_split_done = nullptr;
+void * cb_split_done_user_data = nullptr;
 ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

View file

@@ -38,15 +38,13 @@ struct speculation_context
 bool done;
 };
-speculation_context spec_ctx;
-// pass void * spec_ctx
-static void split_done_cb(int split)
+static void split_done_cb(int split, void * p_spec_ctx)
 {
 if (split == 1 || split == 2)
 {
-std::lock_guard<std::mutex> guard(spec_ctx.mtx);
-spec_ctx.vacant_id = split - 1;
+auto * spec_ctx = static_cast<speculation_context*>(p_spec_ctx);
+std::lock_guard<std::mutex> guard(spec_ctx->mtx);
+spec_ctx->vacant_id = split - 1;
 }
 }
@@ -170,7 +168,8 @@ static int speculation(
 }
 static int target(
-llama_model * model,
+llama_model * model,
+speculation_context * spec_ctx,
 llama_context * ctx,
 const llama_tokens& input,
 size_t n_predict)
@@ -238,8 +237,8 @@ static int target(
 }
 {
-std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-auto & spec = spec_ctx.candidate;
+std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+auto & spec = spec_ctx->candidate;
 size_t n_match = 0;
 for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
 {
@@ -285,8 +284,8 @@ static int target(
 llama_print_timings(ctx);
 fprintf(stderr, "\n");
 {
-std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-spec_ctx.done = true;
+std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+spec_ctx->done = true;
 }
 llama_batch_free(batch);
@@ -306,11 +305,13 @@ int main(int argc, char ** argv) {
 llama_backend_init();
 llama_numa_init(params.numa);
+speculation_context spec_ctx;
 // main model and context
 llama_model * model = nullptr;
 llama_context * ctx = nullptr;
 params.cb_split_done = split_done_cb;
+params.cb_split_done_user_data = &spec_ctx;
 std::tie(model, ctx) = llama_init_from_gpt_params(params);
 llama_tokens input = llama_tokenize(ctx, params.prompt, true);
@@ -333,7 +334,7 @@ int main(int argc, char ** argv) {
 std::tie(draft_model, draft_ctx) = llama_init_from_gpt_params(params);
 std::thread spec_thread = std::thread(speculation, draft_model, &spec_ctx, draft_ctx, input);
-target(model, ctx, input, params.n_predict);
+target(model, &spec_ctx, ctx, input, params.n_predict);
 spec_thread.join();
@@ -346,4 +347,4 @@ int main(int argc, char ** argv) {
 llama_backend_free();
 return 0;
-}
+}
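
The example above drops the file-scope spec_ctx and instead hands the scheduler a pointer to it through the new user_data argument, which split_done_cb casts back to speculation_context*. A minimal standalone sketch of that pattern, using hypothetical names and no ggml dependency, could look like this:

#include <cstdio>
#include <mutex>

// shared state, analogous to speculation_context in the example above
struct shared_state {
    std::mutex mtx;
    int vacant_id = -1;
};

// callback with the new shape: the caller's opaque pointer comes back as user_data
static void on_split_done(int split, void * user_data) {
    auto * st = static_cast<shared_state *>(user_data);
    std::lock_guard<std::mutex> guard(st->mtx);
    st->vacant_id = split - 1;
}

// stand-in for the scheduler loop that invokes the callback once per split
static void run_splits(void (*cb)(int, void *), void * user_data) {
    for (int split = 1; split <= 2; ++split) {
        cb(split, user_data);
    }
}

int main() {
    shared_state st;                  // lives on the caller's stack, no globals
    run_splits(on_split_done, &st);
    std::printf("vacant_id = %d\n", st.vacant_id);
    return 0;
}

The shared state now lives wherever the caller chooses (here on the stack of main) and the scheduler only forwards an opaque pointer, which is what lets several contexts coexist without a global.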

View file

@@ -1076,6 +1076,7 @@ struct ggml_backend_sched {
 void * callback_eval_user_data;
 ggml_backend_sched_split_done_callback callback_split_done;
+void * callback_split_done_user_data;
 // align context_buffer to GGML_MEM_ALIGN
 #ifdef _MSC_VER
@@ -1713,7 +1714,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
 // split finished
 if (sched->callback_split_done) {
-sched->callback_split_done(i);
+sched->callback_split_done(i, sched->callback_split_done_user_data);
 }
 }
@@ -1863,8 +1864,9 @@ void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backe
 sched->callback_eval_user_data = user_data;
 }
-void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback) {
+void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data) {
 sched->callback_split_done = callback;
+sched->callback_split_done_user_data = user_data;
 }
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {

View file

@@ -177,7 +177,7 @@ extern "C" {
 // if set will be called when a split is completed computation
 // is useful for distributed task orchestraction
-typedef void (*ggml_backend_sched_split_done_callback)(int split);
+typedef void (*ggml_backend_sched_split_done_callback)(int split, void * user_data);
 // Initialize a backend scheduler
 GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
@@ -208,7 +208,7 @@ extern "C" {
 GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
 // Set a callback to be called for each resulting node during graph compute
-GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback);
+GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data);
 //
 // Utils
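
With the extra parameter on the typedef and the setter, registering a split-done callback directly on a scheduler could look roughly like this. This is a sketch, not code from the commit: it assumes a valid ggml_backend_sched_t created elsewhere with ggml_backend_sched_new, and count_split / n_splits_seen are hypothetical names.

#include "ggml-backend.h"
#include <atomic>

// hypothetical counter passed through user_data instead of a global
static void count_split(int split, void * user_data) {
    (void) split;
    static_cast<std::atomic<int> *>(user_data)->fetch_add(1);
}

// assumes sched was created earlier with ggml_backend_sched_new(...)
static void install_split_counter(ggml_backend_sched_t sched, std::atomic<int> * n_splits_seen) {
    ggml_backend_sched_set_split_done_callback(sched, count_split, n_splits_seen);
}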

View file

@@ -1863,6 +1863,7 @@ struct llama_cparams {
 void * cb_eval_user_data;
 ggml_backend_sched_split_done_callback cb_split_done;
+void * cb_split_done_user_data;
 };
 struct llama_layer {
@@ -11256,7 +11257,7 @@ static int llama_decode_internal(
 ggml_backend_sched_reset(lctx.sched);
 ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done);
+ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done, lctx.cparams.cb_split_done_user_data);
 ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
@@ -15196,6 +15197,7 @@ struct llama_context_params llama_context_default_params() {
 /*.cb_eval =*/ nullptr,
 /*.cb_eval_user_data =*/ nullptr,
 /*.cb_split_done =*/ nullptr,
+/*.cb_split_done_user_data =*/ nullptr,
 /*.type_k =*/ GGML_TYPE_F16,
 /*.type_v =*/ GGML_TYPE_F16,
 /*.logits_all =*/ false,
@@ -15408,6 +15410,8 @@ struct llama_context * llama_new_context_with_model(
 cparams.cb_eval = params.cb_eval;
 cparams.cb_eval_user_data = params.cb_eval_user_data;
 cparams.cb_split_done = params.cb_split_done;
+cparams.cb_split_done_user_data = params.cb_split_done_user_data;
 auto rope_scaling_type = params.rope_scaling_type;
 if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {

View file

@@ -290,6 +290,7 @@ extern "C" {
 ggml_backend_sched_eval_callback cb_eval;
 void * cb_eval_user_data;
 ggml_backend_sched_split_done_callback cb_split_done;
+void * cb_split_done_user_data;
 enum ggml_type type_k; // data type for K cache
 enum ggml_type type_v; // data type for V cache
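
At the llama API level an application fills in both new fields before creating the context. A hedged sketch, assuming model was loaded earlier (for example with llama_load_model_from_file) and using a hypothetical my_state struct:

#include "llama.h"

struct my_state { int last_split = 0; };   // hypothetical user-side state

static void my_split_done(int split, void * user_data) {
    static_cast<my_state *>(user_data)->last_split = split;
}

// assumes model was loaded earlier; state must outlive the returned context
static llama_context * make_ctx(llama_model * model, my_state * state) {
    llama_context_params cparams = llama_context_default_params();
    cparams.cb_split_done           = my_split_done;
    cparams.cb_split_done_user_data = state;
    return llama_new_context_with_model(model, cparams);
}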