diff --git a/common/common.cpp b/common/common.cpp
index 3e2df6e34..1a992ec38 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1674,6 +1674,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
     cparams.defrag_thold      = params.defrag_thold;
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
diff --git a/common/common.h b/common/common.h
index 99ee90bc3..03abb3b6e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -80,6 +80,9 @@ struct gpt_params {
     int32_t yarn_orig_ctx = 0;    // YaRN original context length
     float   defrag_thold  = -1.0f; // KV cache defragmentation threshold
 
+    ggml_backend_sched_eval_callback cb_eval = nullptr;
+    void * cb_eval_user_data                 = nullptr;
+
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
     llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index d8cb0a642..97f3ae79e 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -596,24 +596,17 @@ int main(int argc, char ** argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model_params mparams = llama_model_params_from_gpt_params(params);
-
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
-    if (model == NULL) {
-        fprintf(stderr, "%s: error: unable to load model\n", __func__);
-        return 1;
-    }
-
-    llama_context_params cparams = llama_context_params_from_gpt_params(params);
-
     // pass the callback to the backend scheduler
     // it will be executed for each node during the graph computation
-    cparams.cb_eval = ik_collect_imatrix;
-    cparams.cb_eval_user_data = NULL;
+    params.cb_eval = ik_collect_imatrix;
+    params.cb_eval_user_data = NULL;
 
-    llama_context * ctx = llama_new_context_with_model(model, cparams);
-    if (ctx == NULL) {
-        fprintf(stderr, "%s: error: unable to create context\n", __func__);
+    // init
+    llama_model * model;
+    llama_context * ctx;
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
     }
 
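
With the callback promoted into `gpt_params`, any tool that initializes through `llama_init_from_gpt_params` can observe graph nodes without duplicating the model/context boilerplate, exactly as `imatrix` now does. Below is a minimal sketch of such a consumer; the callback name and the mul-mat filter are illustrative (not part of this change), while the two-phase `ggml_backend_sched_eval_callback` protocol (first called with `ask == true` to select nodes, then with `ask == false` once a selected node has been computed) is the one defined in `ggml-backend.h`.

```cpp
#include <cstdio>
#include <tuple>

#include "common.h"
#include "ggml.h"
#include "llama.h"

// Illustrative observer: logs every mat-mul node as the backend
// scheduler evaluates the graph.
static bool print_mul_mat_cb(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data; // unused in this sketch

    if (ask) {
        // Phase 1: tell the scheduler which nodes we want to observe;
        // nodes we skip can be batched together and computed without a stop.
        return t->op == GGML_OP_MUL_MAT;
    }

    // Phase 2: the selected node has been computed. Only metadata is read
    // here; reading t->data from an offloaded buffer would require a copy
    // via ggml_backend_tensor_get (see ik_collect_imatrix).
    fprintf(stderr, "observed %s (%s)\n", t->name, ggml_op_name(t->op));

    return true; // returning false cancels the graph computation
}

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    llama_backend_init();
    llama_numa_init(params.numa);

    // Same wiring as imatrix after this change: set the callback on the
    // shared params before the combined init.
    params.cb_eval           = print_mul_mat_cb;
    params.cb_eval_user_data = nullptr;

    llama_model * model;
    llama_context * ctx;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "%s : failed to init\n", __func__);
        return 1;
    }

    // ... run llama_decode() as usual; the callback fires for each node ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

Note that `t->data` is only directly readable when the tensor lives in a host buffer; `ik_collect_imatrix` checks `ggml_backend_buffer_is_host` and copies device-resident tensors back with `ggml_backend_tensor_get` before accumulating statistics.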