diff --git a/ggml-backend.c b/ggml-backend.c index 07482bedf..970495a4c 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1337,24 +1337,22 @@ static void sched_compute_splits(ggml_backend_sched_t sched) { for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { struct ggml_tensor * t = split->graph.nodes[j0]; + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); + int j1 = j0; // determine the range [j0, j1] of nodes that can be computed together - while (j1 < split->graph.n_nodes - 1) { - // check if the user needs data from this node - if (sched->callback_eval(t, true, sched->callback_eval_user_data)) { - break; - } - + while (!need && j1 < split->graph.n_nodes - 1) { t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); } struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); ggml_backend_graph_compute(split_backend, &gv); - if (sched->callback_eval(t, true, sched->callback_eval_user_data) && // ask - !sched->callback_eval(t, false, sched->callback_eval_user_data)) { // eval + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { break; }