From f3965704fd4223d53dc579a6abaac9283a698346 Mon Sep 17 00:00:00 2001
From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com>
Date: Tue, 21 May 2024 22:31:52 -0400
Subject: [PATCH] duo: simplify a little

---
 common/common.cpp    |  1 +
 common/common.h      |  1 +
 examples/duo/duo.cpp | 71 +++++++++++++++++---------------------------
 3 files changed, 29 insertions(+), 44 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index ae11650b4..5dae42814 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1914,6 +1914,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.cb_split_done     = params.cb_split_done;
     cparams.offload_kqv       = !params.no_kv_offload;
     cparams.flash_attn        = params.flash_attn;
 
diff --git a/common/common.h b/common/common.h
index a8e5e50e6..2d5772d19 100644
--- a/common/common.h
+++ b/common/common.h
@@ -86,6 +86,7 @@ struct gpt_params {
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
+    ggml_backend_sched_split_done_callback cb_split_done = nullptr;
 
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index 10a424c5f..5aec3b3f5 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -3,67 +3,50 @@
 #include <cmath>
 #include <cstdio>
+#include <mutex>
 #include <string>
 #include <vector>
 
+using llama_tokens = std::vector<llama_token>;
+
+struct speculation_context
+{
+    llama_tokens speculation;
+    int32_t      instance_id;
+    std::mutex   mtx;
+};
+
+speculation_context spec_ctx;
+
 static void split_done_cb(int split)
 {
-    fprintf(stderr, "split done: %d\n", split);
+    //fprintf(stderr, "split done: %d\n", split);
+    if (split == 1 || split == 2)
+    {
+        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
+        spec_ctx.instance_id = 3 - split;
+    }
 }
 
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
-        return 1 ;
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
     }
 
-    if (argc >= 2) {
-        params.model = argv[1];
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
     }
-
-    if (argc >= 3) {
-        params.prompt = argv[2];
-    }
-
-    if (params.prompt.empty()) {
-        params.prompt = "Hello my name is";
-    }
-
-    llama_model_params model_params = llama_model_default_params();
-    model_params.n_gpu_layers = 99;
-    model_params.rpc_servers = "localhost:50052,localhost:50051";
-
-    const int n_len = 128;
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
-
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
-    }
-
-    // initialize the context
-
-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.seed = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    ctx_params.cb_split_done = split_done_cb;
-
-    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
-    if (ctx == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
-    }
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    params.cb_split_done = split_done_cb;
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
+    const int n_len = 128;
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);