backend : add eval callback (#4935)

* backend : add eval callback

ggml-ci

* backend : group nodes in a single compute when user don't need them

* backend : clean-up the implementation

ggml-ci

* simple : do not perform tensor data copy if not needed

* simple : fix

* simple : no need for ggml_is_contiguous + fix bool parse

* llama : fix callback placement in llama_context_params

* backend : avoid double-ask callback calls

* simple : restore examples, imatrix will serve as a demo
This commit is contained in:
Georgi Gerganov 2024-01-17 18:39:41 +02:00 committed by GitHub
parent c918fe8dca
commit 44a1a4a41a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 64 additions and 2 deletions

View file

@ -1393,6 +1393,9 @@ struct llama_cparams {
bool mul_mat_q;
bool offload_kqv;
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
};
struct llama_layer {
@ -6254,6 +6257,7 @@ static int llama_decode_internal(
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llama_build_graph(lctx, batch);
@ -9276,6 +9280,8 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.mul_mat_q =*/ true,
@ -9416,6 +9422,9 @@ struct llama_context * llama_new_context_with_model(
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
hparams.n_ctx_train;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;