add static ggml_graph_compute_sugar()

This commit is contained in:
mqy 2023-07-03 20:28:07 +08:00
parent db81f33ef2
commit 2b502c32ca

64
ggml.c
View file

@ -16424,7 +16424,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
GGML_ASSERT(plan->n_tasks[i] > 0);
}
}
}
const int n_threads = plan->n_threads;
@ -16491,6 +16490,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
}
}
static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads);
if (plan.work_size > 0) {
plan.work_data = malloc(plan.work_size);
GGML_ASSERT(plan.work_data);
}
ggml_graph_compute(&plan, cgraph);
if (plan.work_data) {
free(plan.work_data);
}
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * grad = cgraph->grads[i];
@ -17327,17 +17340,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
if (plan.work_size > 0) {
plan.work_data = malloc(plan.work_size);
GGML_ASSERT(plan.work_data);
}
ggml_graph_compute(&plan, gb);
if (plan.work_data) {
free(plan.work_data);
}
}
ggml_graph_compute_sugar(gb, params.n_threads);
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev;
@ -17418,17 +17421,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
if (plan.work_size > 0) {
plan.work_data = malloc(plan.work_size);
GGML_ASSERT(plan.work_data);
}
ggml_graph_compute(&plan, gb);
if (plan.work_data) {
free(plan.work_data);
}
}
ggml_graph_compute_sugar(gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0);
@ -17550,17 +17543,7 @@ static enum ggml_opt_result linesearch_backtracking(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads);
if (plan.work_size > 0) {
plan.work_data = malloc(plan.work_size);
GGML_ASSERT(plan.work_data);
}
ggml_graph_compute(&plan, gb);
if (plan.work_data) {
free(plan.work_data);
}
}
ggml_graph_compute_sugar(gb, params->n_threads);
ggml_opt_get_grad(np, ps, g);
@ -17679,17 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
if (plan.work_size > 0) {
plan.work_data = malloc(plan.work_size);
GGML_ASSERT(plan.work_data);
}
ggml_graph_compute(&plan, gb);
if (plan.work_data) {
free(plan.work_data);
}
}
ggml_graph_compute_sugar(gb, params.n_threads);
ggml_opt_get_grad(np, ps, g);