add static ggml_graph_compute_sugar()
This commit is contained in:
parent
db81f33ef2
commit
2b502c32ca
1 changed files with 19 additions and 45 deletions
64
ggml.c
64
ggml.c
|
@ -16424,7 +16424,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
|
|||
GGML_ASSERT(plan->n_tasks[i] > 0);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const int n_threads = plan->n_threads;
|
||||
|
@ -16491,6 +16490,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
plan.work_data = malloc(plan.work_size);
|
||||
GGML_ASSERT(plan.work_data);
|
||||
}
|
||||
|
||||
ggml_graph_compute(&plan, cgraph);
|
||||
|
||||
if (plan.work_data) {
|
||||
free(plan.work_data);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
struct ggml_tensor * grad = cgraph->grads[i];
|
||||
|
@ -17327,17 +17340,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
{
|
||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
plan.work_data = malloc(plan.work_size);
|
||||
GGML_ASSERT(plan.work_data);
|
||||
}
|
||||
ggml_graph_compute(&plan, gb);
|
||||
if (plan.work_data) {
|
||||
free(plan.work_data);
|
||||
}
|
||||
}
|
||||
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||
|
||||
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
|
||||
opt->adam.fx_best = opt->adam.fx_prev;
|
||||
|
@ -17418,17 +17421,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
{
|
||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
plan.work_data = malloc(plan.work_size);
|
||||
GGML_ASSERT(plan.work_data);
|
||||
}
|
||||
ggml_graph_compute(&plan, gb);
|
||||
if (plan.work_data) {
|
||||
free(plan.work_data);
|
||||
}
|
||||
}
|
||||
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||
|
||||
const float fx = ggml_get_f32_1d(f, 0);
|
||||
|
||||
|
@ -17550,17 +17543,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
{
|
||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
plan.work_data = malloc(plan.work_size);
|
||||
GGML_ASSERT(plan.work_data);
|
||||
}
|
||||
ggml_graph_compute(&plan, gb);
|
||||
if (plan.work_data) {
|
||||
free(plan.work_data);
|
||||
}
|
||||
}
|
||||
ggml_graph_compute_sugar(gb, params->n_threads);
|
||||
|
||||
ggml_opt_get_grad(np, ps, g);
|
||||
|
||||
|
@ -17679,17 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
{
|
||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
plan.work_data = malloc(plan.work_size);
|
||||
GGML_ASSERT(plan.work_data);
|
||||
}
|
||||
ggml_graph_compute(&plan, gb);
|
||||
if (plan.work_data) {
|
||||
free(plan.work_data);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||
|
||||
ggml_opt_get_grad(np, ps, g);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue