duo: simplify a little
This commit is contained in:
parent
d52d193e58
commit
f3965704fd
3 changed files with 29 additions and 44 deletions
|
@ -1914,6 +1914,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||||
cparams.defrag_thold = params.defrag_thold;
|
cparams.defrag_thold = params.defrag_thold;
|
||||||
cparams.cb_eval = params.cb_eval;
|
cparams.cb_eval = params.cb_eval;
|
||||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||||
|
cparams.cb_split_done = params.cb_split_done;
|
||||||
cparams.offload_kqv = !params.no_kv_offload;
|
cparams.offload_kqv = !params.no_kv_offload;
|
||||||
cparams.flash_attn = params.flash_attn;
|
cparams.flash_attn = params.flash_attn;
|
||||||
|
|
||||||
|
|
|
@ -86,6 +86,7 @@ struct gpt_params {
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||||
void * cb_eval_user_data = nullptr;
|
void * cb_eval_user_data = nullptr;
|
||||||
|
ggml_backend_sched_split_done_callback cb_split_done = nullptr;
|
||||||
|
|
||||||
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
||||||
|
|
||||||
|
|
|
@ -3,67 +3,50 @@
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
using llama_tokens = std::vector<llama_token>;
|
||||||
|
|
||||||
|
struct speculation_context
|
||||||
|
{
|
||||||
|
llama_tokens speculation;
|
||||||
|
int32_t instance_id;
|
||||||
|
std::mutex mtx;
|
||||||
|
};
|
||||||
|
|
||||||
|
speculation_context spec_ctx;
|
||||||
|
|
||||||
static void split_done_cb(int split)
|
static void split_done_cb(int split)
|
||||||
{
|
{
|
||||||
fprintf(stderr, "split done: %d\n", split);
|
//fprintf(stderr, "split done: %d\n", split);
|
||||||
|
if (split == 1 || split == 2)
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(spec_ctx.mtx);
|
||||||
|
spec_ctx.instance_id = 3 - split;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
if (argc == 1 || argv[1][0] == '-') {
|
if (gpt_params_parse(argc, argv, params) == false) {
|
||||||
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
|
return 1;
|
||||||
return 1 ;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (argc >= 2) {
|
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||||
params.model = argv[1];
|
params.seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (argc >= 3) {
|
|
||||||
params.prompt = argv[2];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.prompt.empty()) {
|
|
||||||
params.prompt = "Hello my name is";
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_model_params model_params = llama_model_default_params();
|
|
||||||
model_params.n_gpu_layers = 99;
|
|
||||||
model_params.rpc_servers = "localhost:50052,localhost:50051";
|
|
||||||
|
|
||||||
const int n_len = 128;
|
|
||||||
|
|
||||||
llama_backend_init();
|
llama_backend_init();
|
||||||
llama_numa_init(params.numa);
|
llama_numa_init(params.numa);
|
||||||
|
|
||||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
llama_model * model = nullptr;
|
||||||
|
llama_context * ctx = nullptr;
|
||||||
if (model == NULL) {
|
params.cb_split_done = split_done_cb;
|
||||||
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize the context
|
|
||||||
|
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
|
||||||
|
|
||||||
ctx_params.seed = 1234;
|
|
||||||
ctx_params.n_ctx = 2048;
|
|
||||||
ctx_params.n_threads = params.n_threads;
|
|
||||||
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
|
||||||
ctx_params.cb_split_done = split_done_cb;
|
|
||||||
|
|
||||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
|
||||||
|
|
||||||
if (ctx == NULL) {
|
|
||||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const int n_len = 128;
|
||||||
std::vector<llama_token> tokens_list;
|
std::vector<llama_token> tokens_list;
|
||||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue