diff --git a/ggml.c b/ggml.c index 4968f36c2..0e906d0c3 100644 --- a/ggml.c +++ b/ggml.c @@ -16424,7 +16424,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap GGML_ASSERT(plan->n_tasks[i] > 0); } } - } const int n_threads = plan->n_threads; @@ -16491,6 +16490,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap } } +static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + + ggml_graph_compute(&plan, cgraph); + + if (plan.work_data) { + free(plan.work_data); + } +} + void ggml_graph_reset(struct ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * grad = cgraph->grads[i]; @@ -17327,17 +17340,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - { - struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); - if (plan.work_size > 0) { - plan.work_data = malloc(plan.work_size); - GGML_ASSERT(plan.work_data); - } - ggml_graph_compute(&plan, gb); - if (plan.work_data) { - free(plan.work_data); - } - } + ggml_graph_compute_sugar(gb, params.n_threads); opt->adam.fx_prev = ggml_get_f32_1d(f, 0); opt->adam.fx_best = opt->adam.fx_prev; @@ -17418,17 +17421,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - { - struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); - if (plan.work_size > 0) { - plan.work_data = malloc(plan.work_size); - GGML_ASSERT(plan.work_data); - } - ggml_graph_compute(&plan, gb); - if (plan.work_data) { - free(plan.work_data); - } - } + ggml_graph_compute_sugar(gb, params.n_threads); const float fx = ggml_get_f32_1d(f, 0); @@ -17550,17 +17543,7 @@ static enum ggml_opt_result linesearch_backtracking( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - { - struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads); - if (plan.work_size > 0) { - plan.work_data = malloc(plan.work_size); - GGML_ASSERT(plan.work_data); - } - ggml_graph_compute(&plan, gb); - if (plan.work_data) { - free(plan.work_data); - } - } + ggml_graph_compute_sugar(gb, params->n_threads); ggml_opt_get_grad(np, ps, g); @@ -17679,17 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - { - struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); - if (plan.work_size > 0) { - plan.work_data = malloc(plan.work_size); - GGML_ASSERT(plan.work_data); - } - ggml_graph_compute(&plan, gb); - if (plan.work_data) { - free(plan.work_data); - } - } + + ggml_graph_compute_sugar(gb, params.n_threads); ggml_opt_get_grad(np, ps, g);