duo: simplify a little

Oleksandr Kuvshynov 2024-05-21 22:31:52 -04:00
parent d52d193e58
commit f3965704fd
3 changed files with 29 additions and 44 deletions

@@ -1914,6 +1914,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.cb_split_done     = params.cb_split_done;
     cparams.offload_kqv       = !params.no_kv_offload;
     cparams.flash_attn        = params.flash_attn;

@@ -86,6 +86,7 @@ struct gpt_params {
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
+    ggml_backend_sched_split_done_callback cb_split_done = nullptr;
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

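The two hunks above only thread the new callback through the common layer: gpt_params gains a cb_split_done member, and llama_context_params_from_gpt_params copies it into the context parameters the same way cb_eval and cb_eval_user_data already flow. Below is a minimal standalone sketch of that plumbing pattern; the *_like structs and cparams_from_params are illustrative stand-ins, not the fork's real types.

#include <cstdio>

// Illustrative stand-ins for the structs touched above; these are NOT the
// fork's real types, only the shape of the callback plumbing.
typedef void (*split_done_callback)(int split);

struct gpt_params_like    { split_done_callback cb_split_done = nullptr; };
struct llama_cparams_like { split_done_callback cb_split_done = nullptr; };

// Mirrors what the common-layer conversion does for the new field:
// copy the user-supplied callback into the context parameters.
static llama_cparams_like cparams_from_params(const gpt_params_like & params) {
    llama_cparams_like cparams;
    cparams.cb_split_done = params.cb_split_done;
    return cparams;
}

static void on_split_done(int split) {
    printf("split done: %d\n", split);
}

int main() {
    gpt_params_like params;
    params.cb_split_done = on_split_done;

    llama_cparams_like cparams = cparams_from_params(params);
    if (cparams.cb_split_done) {
        cparams.cb_split_done(1); // the scheduler side would invoke this per finished split
    }
    return 0;
}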

@@ -3,67 +3,50 @@
 #include <cmath>
 #include <cstdio>
+#include <memory>
 #include <string>
 #include <vector>
 
+using llama_tokens = std::vector<llama_token>;
+
+struct speculation_context
+{
+    llama_tokens speculation;
+    int32_t      instance_id;
+    std::mutex   mtx;
+};
+
+speculation_context spec_ctx;
+
 static void split_done_cb(int split)
 {
-    fprintf(stderr, "split done: %d\n", split);
+    //fprintf(stderr, "split done: %d\n", split);
+    if (split == 1 || split == 2)
+    {
+        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
+        spec_ctx.instance_id = 3 - split;
+    }
 }
 
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
-        return 1 ;
-    }
-
-    if (argc >= 2) {
-        params.model = argv[1];
-    }
-
-    if (argc >= 3) {
-        params.prompt = argv[2];
-    }
-
-    if (params.prompt.empty()) {
-        params.prompt = "Hello my name is";
-    }
-
-    llama_model_params model_params = llama_model_default_params();
-    model_params.n_gpu_layers = 99;
-    model_params.rpc_servers = "localhost:50052,localhost:50051";
-
-    const int n_len = 128;
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
 
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
-
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
-    }
-
-    // initialize the context
-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    ctx_params.cb_split_done = split_done_cb;
-
-    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
-    if (ctx == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
-    }
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+
+    params.cb_split_done = split_done_cb;
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    const int n_len = 128;
 
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);
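
For context on the new callback body: spec_ctx.instance_id = 3 - split maps split 1 to instance 2 and split 2 to instance 1, so whichever instance just finished its split hands control to the other, with the mutex guarding the shared state. The standalone sketch below replays that handoff; run_instance is a hypothetical stand-in for the duo decode loops and is not part of this commit.

#include <cstdint>
#include <cstdio>
#include <mutex>

// Same shared-state shape as in the diff above.
struct speculation_context {
    int32_t    instance_id = 1;
    std::mutex mtx;
};

static speculation_context spec_ctx;

// Same body as the new split_done_cb: finishing split 1 or 2 hands
// control to the other instance (3 - 1 == 2, 3 - 2 == 1).
static void split_done_cb(int split) {
    if (split == 1 || split == 2) {
        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
        spec_ctx.instance_id = 3 - split;
    }
}

// Hypothetical stand-in for one of the duo decode loops.
static void run_instance(int id) {
    {
        std::lock_guard<std::mutex> guard(spec_ctx.mtx);
        if (spec_ctx.instance_id != id) {
            return; // not this instance's turn
        }
    }
    printf("instance %d runs its split\n", id);
    split_done_cb(id); // done: flip the active instance
}

int main() {
    for (int step = 0; step < 4; step++) {
        run_instance(1);
        run_instance(2);
    }
    return 0;
}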