pass user data
parent 534093878b
commit 7c8699add6
7 changed files with 28 additions and 18 deletions
@@ -1923,6 +1923,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;

@@ -88,6 +88,7 @@ struct gpt_params {
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
     ggml_backend_sched_split_done_callback cb_split_done = nullptr;
+    void * cb_split_done_user_data = nullptr;

     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

@@ -38,15 +38,13 @@ struct speculation_context
     bool done;
 };

-speculation_context spec_ctx;
-
-// pass void * spec_ctx
-static void split_done_cb(int split)
+static void split_done_cb(int split, void * p_spec_ctx)
 {
     if (split == 1 || split == 2)
     {
-        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
-        spec_ctx.vacant_id = split - 1;
+        auto * spec_ctx = static_cast<speculation_context*>(p_spec_ctx);
+        std::lock_guard<std::mutex> guard(spec_ctx->mtx);
+        spec_ctx->vacant_id = split - 1;
     }
 }

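The hunk above removes the file-scope spec_ctx and instead recovers the state from the opaque pointer the scheduler now hands back. A minimal, self-contained sketch of that pattern, with scheduler_like standing in for ggml_backend_sched (illustrative names only, not part of the commit):

#include <cstdio>
#include <mutex>

struct speculation_context {
    std::mutex mtx;
    int        vacant_id = -1;
};

// same shape as the new callback type: the split index plus an opaque pointer
typedef void (*split_done_fn)(int split, void * user_data);

struct scheduler_like {                    // stand-in for ggml_backend_sched
    split_done_fn cb           = nullptr;
    void *        cb_user_data = nullptr;
    void finish_split(int i) { if (cb) cb(i, cb_user_data); }
};

static void split_done_cb(int split, void * p_spec_ctx) {
    auto * spec_ctx = static_cast<speculation_context *>(p_spec_ctx);
    std::lock_guard<std::mutex> guard(spec_ctx->mtx);
    spec_ctx->vacant_id = split - 1;
}

int main() {
    speculation_context spec_ctx;          // caller-owned, no global needed
    scheduler_like sched;
    sched.cb           = split_done_cb;
    sched.cb_user_data = &spec_ctx;
    sched.finish_split(1);
    std::printf("vacant_id = %d\n", spec_ctx.vacant_id);   // prints 0
    return 0;
}
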
@@ -171,6 +169,7 @@ static int speculation(

 static int target(
     llama_model * model,
+    speculation_context * spec_ctx,
     llama_context * ctx,
     const llama_tokens& input,
     size_t n_predict)
@@ -238,8 +237,8 @@ static int target(
     }

     {
-        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-        auto & spec = spec_ctx.candidate;
+        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+        auto & spec = spec_ctx->candidate;
         size_t n_match = 0;
         for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
         {
@@ -285,8 +284,8 @@ static int target(
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
     {
-        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
-        spec_ctx.done = true;
+        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+        spec_ctx->done = true;
     }

     llama_batch_free(batch);
@@ -306,11 +305,13 @@ int main(int argc, char ** argv) {

     llama_backend_init();
     llama_numa_init(params.numa);
+    speculation_context spec_ctx;

     // main model and context
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     params.cb_split_done = split_done_cb;
+    params.cb_split_done_user_data = &spec_ctx;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);

     llama_tokens input = llama_tokenize(ctx, params.prompt, true);
@@ -333,7 +334,7 @@ int main(int argc, char ** argv) {
     std::tie(draft_model, draft_ctx) = llama_init_from_gpt_params(params);
     std::thread spec_thread = std::thread(speculation, draft_model, &spec_ctx, draft_ctx, input);

-    target(model, ctx, input, params.n_predict);
+    target(model, &spec_ctx, ctx, input, params.n_predict);

     spec_thread.join();

@@ -1076,6 +1076,7 @@ struct ggml_backend_sched {
     void * callback_eval_user_data;

     ggml_backend_sched_split_done_callback callback_split_done;
+    void * callback_split_done_user_data;

     // align context_buffer to GGML_MEM_ALIGN
 #ifdef _MSC_VER
@@ -1713,7 +1714,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s

         // split finished
         if (sched->callback_split_done) {
-            sched->callback_split_done(i);
+            sched->callback_split_done(i, sched->callback_split_done_user_data);
         }
     }

@@ -1863,8 +1864,9 @@ void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backe
     sched->callback_eval_user_data = user_data;
 }

-void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback) {
+void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data) {
     sched->callback_split_done = callback;
+    sched->callback_split_done_user_data = user_data;
 }

 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
@@ -177,7 +177,7 @@ extern "C" {

     // if set will be called when a split is completed computation
     // is useful for distributed task orchestraction
-    typedef void (*ggml_backend_sched_split_done_callback)(int split);
+    typedef void (*ggml_backend_sched_split_done_callback)(int split, void * user_data);

     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
@@ -208,7 +208,7 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);

     // Set a callback to be called for each resulting node during graph compute
-    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback);
+    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback, void * user_data);

     //
     // Utils

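With the updated declarations, a client of the scheduler API registers its callback together with a pointer to caller-owned state. A hedged sketch; my_state and on_split_done are illustrative, only the ggml_backend_sched_* names come from the header above:

#include "ggml-backend.h"

struct my_state { int last_split = -1; };   // hypothetical caller-owned state

static void on_split_done(int split, void * user_data) {
    // recover the state that was registered alongside the callback
    static_cast<my_state *>(user_data)->last_split = split;
}

// given an already constructed scheduler `sched`:
//     my_state state;
//     ggml_backend_sched_set_split_done_callback(sched, on_split_done, &state);
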
@@ -1863,6 +1863,7 @@ struct llama_cparams {
     void * cb_eval_user_data;

     ggml_backend_sched_split_done_callback cb_split_done;
+    void * cb_split_done_user_data;
 };

 struct llama_layer {
@@ -11256,7 +11257,7 @@ static int llama_decode_internal(

     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-    ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done);
+    ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done, lctx.cparams.cb_split_done_user_data);

     ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);

@@ -15196,6 +15197,7 @@ struct llama_context_params llama_context_default_params() {
         /*.cb_eval =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
         /*.cb_split_done =*/ nullptr,
+        /*.cb_split_done_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
         /*.logits_all =*/ false,
@@ -15408,6 +15410,8 @@ struct llama_context * llama_new_context_with_model(
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.cb_split_done = params.cb_split_done;
+    cparams.cb_split_done_user_data = params.cb_split_done_user_data;

     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {

llama.h

@@ -290,6 +290,7 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
         ggml_backend_sched_split_done_callback cb_split_done;
+        void * cb_split_done_user_data;

         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache
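At the llama.h level the two new fields are filled in on the context params before the context is created; a hedged sketch reusing the illustrative on_split_done/state from the previous example:

// sketch only: assumes `model` was loaded elsewhere and on_split_done/state exist
struct llama_context_params cparams = llama_context_default_params();
cparams.cb_split_done           = on_split_done;
cparams.cb_split_done_user_data = &state;
llama_context * ctx = llama_new_context_with_model(model, cparams);
// decoding then forwards both values to the backend scheduler (see the
// ggml_backend_sched_set_split_done_callback call in llama_decode_internal above)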